!19605 add federated learning albert
Merge pull request !19605 from wtcheng/master
This commit is contained in: commit 6863d52646

@ -0,0 +1,101 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import os
import sys
from time import time
from mindspore import context
from mindspore.train.serialization import load_checkpoint
from src.config import eval_cfg, server_net_cfg
from src.dataset import load_datasets
from src.utils import restore_params
from src.model import AlbertModelCLS
from src.tokenization import CustomizedTextTokenizer
from src.assessment_method import Accuracy


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='server eval task')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
    parser.add_argument('--device_id', type=str, default='0')
    parser.add_argument('--tokenizer_dir', type=str, default='../model_save/init/')
    parser.add_argument('--eval_data_dir', type=str, default='../datasets/eval/')
    parser.add_argument('--model_path', type=str, default='../model_save/train_server/0.ckpt')
    parser.add_argument('--vocab_map_ids_path', type=str, default='../model_save/init/vocab_map_ids.txt')

    return parser.parse_args()


def server_eval(args):
    start = time()
    # some parameters
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
    tokenizer_dir = args.tokenizer_dir
    eval_data_dir = args.eval_data_dir
    model_path = args.model_path
    vocab_map_ids_path = args.vocab_map_ids_path

    # mindspore context (honor --device_target instead of hard-coding 'GPU')
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    print('Context setting is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # data process
    tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
    datasets_list, _ = load_datasets(
        eval_data_dir, server_net_cfg.seq_length, tokenizer, eval_cfg.batch_size,
        label_list=None,
        do_shuffle=False,
        drop_remainder=False,
        output_dir=None)
    print('Data process is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # main model
    albert_model_cls = AlbertModelCLS(server_net_cfg)
    albert_model_cls.set_train(False)
    param_dict = load_checkpoint(model_path)
    restore_params(albert_model_cls, param_dict)
    print('Model construction is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # eval
    callback = Accuracy()
    global_step = 0
    for dataset in datasets_list:
        for batch in dataset.create_tuple_iterator():
            input_ids, attention_mask, token_type_ids, label_ids, _ = batch
            logits = albert_model_cls(input_ids, attention_mask, token_type_ids)
            callback.update(logits, label_ids)
            print('eval step: {}, {}: {}'.format(global_step, callback.name, callback.get_metrics()))
            sys.stdout.flush()
            global_step += 1
    metrics = callback.get_metrics()
    print('Final {}: {}'.format(callback.name, metrics))
    sys.stdout.flush()
    print('Evaluating process is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()


if __name__ == '__main__':
    args_opt = parse_args()
    server_eval(args_opt)
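Note: restore_params is imported from src/utils, which this commit excerpt does not show. As a rough sketch of what such a helper usually does, assuming MindSpore's standard load_param_into_net semantics (the function name and the filtering rule below are illustrative, not the actual implementation):

# Illustrative sketch only; the real restore_params lives in src/utils (not shown here).
from mindspore.train.serialization import load_param_into_net

def restore_params_sketch(network, param_dict):
    """Load checkpoint values into `network`, skipping keys the net does not have."""
    net_keys = {name for name, _ in network.parameters_and_names()}
    filtered = {k: v for k, v in param_dict.items() if k in net_keys}
    load_param_into_net(network, filtered)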
@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import os
import sys
from time import time
from mindspore import context
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.adam import AdamWeightDecayOp as AdamWeightDecay
from src.tokenization import CustomizedTextTokenizer
from src.config import train_cfg, server_net_cfg
from src.dataset import load_dataset
from src.utils import restore_params
from src.model import AlbertModelCLS
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='server task')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
    parser.add_argument('--device_id', type=str, default='0')
    parser.add_argument('--tokenizer_dir', type=str, default='../model_save/init/')
    parser.add_argument('--server_data_path', type=str, default='../datasets/semi_supervise/server/train.txt')
    parser.add_argument('--model_path', type=str, default='../model_save/init/albert_init.ckpt')
    parser.add_argument('--output_dir', type=str, default='../model_save/train_server/')
    parser.add_argument('--vocab_map_ids_path', type=str, default='../model_save/init/vocab_map_ids.txt')
    parser.add_argument('--logging_step', type=int, default=1)

    parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
    parser.add_argument("--ms_role", type=str, default="MS_WORKER")
    parser.add_argument("--worker_num", type=int, default=0)
    parser.add_argument("--server_num", type=int, default=1)
    parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
    parser.add_argument("--scheduler_port", type=int, default=8113)
    parser.add_argument("--fl_server_port", type=int, default=6666)
    parser.add_argument("--start_fl_job_threshold", type=int, default=1)
    parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
    parser.add_argument("--update_model_ratio", type=float, default=1.0)
    parser.add_argument("--update_model_time_window", type=int, default=3000)
    parser.add_argument("--fl_name", type=str, default="Lenet")
    parser.add_argument("--fl_iteration_num", type=int, default=25)
    parser.add_argument("--client_epoch_num", type=int, default=20)
    parser.add_argument("--client_batch_size", type=int, default=32)
    parser.add_argument("--client_learning_rate", type=float, default=0.1)
    parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
    parser.add_argument("--scheduler_manage_port", type=int, default=11202)
    parser.add_argument("--dp_eps", type=float, default=50.0)
    parser.add_argument("--dp_delta", type=float, default=0.01)  # usually equals 1/start_fl_job_threshold
    parser.add_argument("--dp_norm_clip", type=float, default=0.05)
    parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
    return parser.parse_args()


def server_train(args):
    start = time()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
    tokenizer_dir = args.tokenizer_dir
    server_data_path = args.server_data_path
    model_path = args.model_path
    output_dir = args.output_dir
    vocab_map_ids_path = args.vocab_map_ids_path
    logging_step = args.logging_step

    device_target = args.device_target
    server_mode = args.server_mode
    ms_role = args.ms_role
    worker_num = args.worker_num
    server_num = args.server_num
    scheduler_ip = args.scheduler_ip
    scheduler_port = args.scheduler_port
    fl_server_port = args.fl_server_port
    start_fl_job_threshold = args.start_fl_job_threshold
    start_fl_job_time_window = args.start_fl_job_time_window
    update_model_ratio = args.update_model_ratio
    update_model_time_window = args.update_model_time_window
    fl_name = args.fl_name
    fl_iteration_num = args.fl_iteration_num
    client_epoch_num = args.client_epoch_num
    client_batch_size = args.client_batch_size
    client_learning_rate = args.client_learning_rate
    scheduler_manage_port = args.scheduler_manage_port
    dp_delta = args.dp_delta
    dp_norm_clip = args.dp_norm_clip
    encrypt_type = args.encrypt_type

    # Replace some parameters with federated learning parameters.
    train_cfg.max_global_epoch = fl_iteration_num

    fl_ctx = {
        "enable_fl": True,
        "server_mode": server_mode,
        "ms_role": ms_role,
        "worker_num": worker_num,
        "server_num": server_num,
        "scheduler_ip": scheduler_ip,
        "scheduler_port": scheduler_port,
        "fl_server_port": fl_server_port,
        "start_fl_job_threshold": start_fl_job_threshold,
        "start_fl_job_time_window": start_fl_job_time_window,
        "update_model_ratio": update_model_ratio,
        "update_model_time_window": update_model_time_window,
        "fl_name": fl_name,
        "fl_iteration_num": fl_iteration_num,
        "client_epoch_num": client_epoch_num,
        "client_batch_size": client_batch_size,
        "client_learning_rate": client_learning_rate,
        "scheduler_manage_port": scheduler_manage_port,
        "dp_delta": dp_delta,
        "dp_norm_clip": dp_norm_clip,
        "encrypt_type": encrypt_type
    }

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # construct tokenizer
    tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
    print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # mindspore context
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
    context.set_fl_context(**fl_ctx)
    print('Context setting is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # construct model
    albert_model_cls = AlbertModelCLS(server_net_cfg)
    network_with_cls_loss = NetworkWithCLSLoss(albert_model_cls)
    network_with_cls_loss.set_train(True)

    print('Model construction is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # train prepare
    global_step = 0
    param_dict = load_checkpoint(model_path)
    if 'learning_rate' in param_dict:
        del param_dict['learning_rate']

    # server optimizer (word embeddings and the MLM postprocessor are excluded)
    server_params = [p for p in network_with_cls_loss.trainable_params()
                     if 'word_embeddings' not in p.name
                     and 'postprocessor' not in p.name]
    server_decay_params = list(
        filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
    )
    server_other_params = list(
        filter(lambda x: not train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter(x), server_params)
    )
    server_group_params = [
        {'params': server_decay_params, 'weight_decay': train_cfg.optimizer_cfg.AdamWeightDecay.weight_decay},
        {'params': server_other_params, 'weight_decay': 0.0},
        {'order_params': server_params}
    ]
    server_optimizer = AdamWeightDecay(server_group_params,
                                       learning_rate=train_cfg.server_cfg.learning_rate,
                                       eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
    server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)

    restore_params(server_network_train_cell, param_dict)

    print('Optimizer construction is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # server load data
    server_train_dataset, _ = load_dataset(
        server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
        label_list=None,
        do_shuffle=True,
        drop_remainder=True,
        output_dir=None,
        cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
    )
    print('Server data loading is done! Time cost: {}'.format(time() - start))
    start = time()

    # train process
    for global_epoch in range(train_cfg.max_global_epoch):
        for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
            for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
                input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
                model_start_time = time()
                cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
                time_cost = time() - model_start_time
                if global_step % logging_step == 0:
                    print_text = 'server: '
                    print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
                    print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
                    print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
                    print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
                    print(print_text)
                    sys.stdout.flush()
                global_step += 1
                del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
            output_path = os.path.join(
                output_dir,
                str(global_epoch * train_cfg.server_cfg.max_local_epoch + server_local_epoch) + '.ckpt'
            )
            save_checkpoint(server_network_train_cell.network, output_path)

    print('Training process is done! Time cost: {}'.format(time() - start))


if __name__ == '__main__':
    args_opt = parse_args()
    server_train(args_opt)
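Note: in federated mode the same training entry point serves three roles, selected by --ms_role (MS_SCHED, MS_SERVER, MS_WORKER), and the scheduler must be reachable before the servers register. The launcher scripts further down in this commit reflect that ordering. A hypothetical end-to-end driver (the launcher file names are not visible in this diff, so run_sched.py and run_server.py below are assumptions):

# Hypothetical driver; run_sched.py / run_server.py stand in for the launcher
# scripts in this commit, whose real file names the diff viewer omits.
import subprocess
import time

subprocess.call(['python', 'run_sched.py', '--scheduler_port=8113'])
time.sleep(1)  # let the scheduler bind its ports before the servers register
subprocess.call(['python', 'run_server.py', '--scheduler_port=8113', '--local_server_num=2'])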
@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess

parser = argparse.ArgumentParser(description="Finish cloud_train.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)

args, _ = parser.parse_known_args()
scheduler_port = args.scheduler_port

cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"

subprocess.call(['bash', '-c', cmd])
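The kill logic shells out to ps/grep/awk, which keeps the script dependency-free. For comparison, a pure-Python equivalent built on the third-party psutil package (an assumption; the original deliberately sticks to bash) could look like:

# Sketch of a psutil-based alternative; psutil is a third-party package
# and is not used anywhere in the original scripts.
import psutil

def kill_by_scheduler_port(scheduler_port):
    needle = 'scheduler_port=' + str(scheduler_port)
    for proc in psutil.process_iter(['pid', 'cmdline']):
        cmdline = ' '.join(proc.info['cmdline'] or [])
        if needle in cmdline and 'finish' not in cmdline:
            proc.kill()
            print('killed', proc.info['pid'])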
@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess

parser = argparse.ArgumentParser(description="Run cloud_train.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)

args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port

cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
cmd_sched += "python ${self_path}/../cloud_train.py"
cmd_sched += " --device_target=" + device_target
cmd_sched += " --server_mode=" + server_mode
cmd_sched += " --ms_role=MS_SCHED"
cmd_sched += " --worker_num=" + str(worker_num)
cmd_sched += " --server_num=" + str(server_num)
cmd_sched += " --scheduler_ip=" + scheduler_ip
cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " > scheduler.log 2>&1 &"

subprocess.call(['bash', '-c', cmd_sched])
@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess
import time

parser = argparse.ArgumentParser(description="Run cloud_train.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
# The number of servers that this script will launch.
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)  # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=0.05)
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")

args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type

if local_server_num == -1:
    local_server_num = server_num

assert local_server_num <= server_num, "The local server number should not be bigger than the total server number."

for i in range(local_server_num):
    cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
    cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
    cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
    cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
    cmd_server += "python ${self_path}/../cloud_train.py"
    cmd_server += " --device_target=" + device_target
    cmd_server += " --server_mode=" + server_mode
    cmd_server += " --ms_role=MS_SERVER"
    cmd_server += " --worker_num=" + str(worker_num)
    cmd_server += " --server_num=" + str(server_num)
    cmd_server += " --scheduler_ip=" + scheduler_ip
    cmd_server += " --scheduler_port=" + str(scheduler_port)
    cmd_server += " --fl_server_port=" + str(fl_server_port + i)
    cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
    cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
    cmd_server += " --update_model_ratio=" + str(update_model_ratio)
    cmd_server += " --update_model_time_window=" + str(update_model_time_window)
    cmd_server += " --fl_name=" + fl_name
    cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
    cmd_server += " --client_epoch_num=" + str(client_epoch_num)
    cmd_server += " --client_batch_size=" + str(client_batch_size)
    cmd_server += " --client_learning_rate=" + str(client_learning_rate)
    cmd_server += " --dp_eps=" + str(dp_eps)
    cmd_server += " --dp_delta=" + str(dp_delta)
    cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
    cmd_server += " --encrypt_type=" + str(encrypt_type)
    cmd_server += " > server.log 2>&1 &"

    time.sleep(0.3)
    subprocess.call(['bash', '-c', cmd_server])
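Each locally launched server binds its own fl_server_port, offset by its index, so several server processes can share one machine. A quick illustration of the resulting ports under the defaults above:

# Ports assigned to locally launched servers under the defaults above.
fl_server_port = 6666
local_server_num = 2
print([fl_server_port + i for i in range(local_server_num)])  # -> [6666, 6667]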
@ -0,0 +1,424 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""AdamWeightDecayForBert, a customized Adam for BERT. Input: gradient, overflow flag."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Update parameters by the fused AdamWeightDecay op.
    """
    if optim_filter:
        adam = P.AdamWeightDecay()
        if decay_flags:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
        else:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
        return next_param
    return gradient


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the updated parameter.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # Broadcast the overflow flag to the moment shape; when it is set, the
        # parameter update below is replaced by zeros.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1,
                                                          gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                                                          op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        assign_m = F.assign(m, op_mul(beta1, m))
        assign_v = F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            m_recover = F.assign(m, m_temp / _scaler_ten)

            F.control_depend(m_temp, assign_m_nesterov)
            F.control_depend(assign_m_nesterov, div_value)
            F.control_depend(param_update, m_recover)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update

        F.control_depend(assign_m, next_m)
        F.control_depend(assign_v, next_v)

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """Apply adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(param), op_shape(moment1), op_shape(moment2))), param))
    else:
        success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply AdamOffload optimizer to the weight parameter using Tensor."""
    success = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm with the weight decay fix.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve performance when using parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            each element in `params` must be of class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There should be no other keys in the `dict`, and the
              parameters in 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used, and the i-th learning rate is calculated during training according to the
            formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a
            fixed learning rate is used. Other cases are not supported. A float learning rate must be equal to or
            greater than 0. If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (Tensor) - The overflow flag from dynamic loss scaling.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayForBert(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """AdamWeightDecayForBert"""
        lr = self.get_lr()
        # Broadcast the loss-scale overflow flag to the shape of beta1/beta2.
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                            self.op_reshape(overflow, (())), mstype.bool_)
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result


class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with the weight decay fix. This version is backed by a single fused operator
    rather than a combination of smaller ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve performance when using parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            each element in `params` must be of class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There should be no other keys in the `dict`, and the
              parameters in 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used, and the i-th learning rate is calculated during training according to the
            formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a
            fixed learning rate is used. Other cases are not supported. A float learning rate must be equal to or
            greater than 0. If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """AdamWeightDecayOp"""
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
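For reference, the update that _update_run_op computes on the non-overflow path is ordinary AdamWeightDecay without bias correction. A plain-NumPy sketch of one step, handy when sanity-checking the graph kernel above:

# NumPy reference for one AdamWeightDecay step (no overflow masking, no bias correction),
# mirroring the arithmetic in _update_run_op.
import numpy as np

def adam_weight_decay_step(param, m, v, grad, lr, beta1=0.9, beta2=0.999,
                           eps=1e-6, weight_decay=0.0):
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    update = m / (eps + np.sqrt(v))          # eps is added before the sqrt term, as above
    if weight_decay > 0.0:
        update += weight_decay * param       # decoupled weight decay
    return param - lr * update, m, v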
@ -0,0 +1,138 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""assessment methods"""

import numpy as np


class Accuracy:
    """Accuracy"""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0
        self.name = 'Accuracy'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)

    def get_metrics(self):
        return self.acc_num / self.total_num * 100.0


class TopK:
    """TopK"""
    def __init__(self, k=5):
        self.acc_num = 0
        self.total_num = 0
        self.k = k
        self.name = 'Top' + str(k)

    def update(self, logits, labels):
        labels = labels.asnumpy()
        logits = logits.asnumpy()
        sorted_index = logits.argsort()
        for i, label in enumerate(labels):
            for j in range(self.k):
                if sorted_index[i, -j-1] == label:
                    self.acc_num += 1
                    break
        self.total_num += len(labels)

    def get_metrics(self):
        return self.acc_num / self.total_num * 100.0


class F1:
    """F1"""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'F1'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.argmax(logits, axis=1)
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(np.bool_)
        self.logits_array = np.concatenate([self.logits_array, logits]).astype(np.bool_)

    def get_metrics(self):
        if len(self.labels_array) < 2:
            return 0.0
        tp = np.sum(self.labels_array & self.logits_array)
        # label positive but predicted negative is a false negative, and vice versa;
        # the original had the fp/fn names swapped, which left the F1 value unchanged
        # (F1 is symmetric in precision and recall) but made the code misleading.
        fn = np.sum(self.labels_array & (~self.logits_array))
        fp = np.sum((~self.labels_array) & self.logits_array)
        p = tp / (tp + fp)
        r = tp / (tp + fn)
        return 2.0 * p * r / (p + r) * 100.0


class Pearsonr:
    """Pearsonr"""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'Pearsonr'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.reshape(logits, -1)
        self.labels_array = np.concatenate([self.labels_array, labels])
        self.logits_array = np.concatenate([self.logits_array, logits])

    def get_metrics(self):
        if len(self.labels_array) < 2:
            return 0.0
        x_mean = self.logits_array.mean()
        y_mean = self.labels_array.mean()
        xm = self.logits_array - x_mean
        ym = self.labels_array - y_mean
        norm_xm = np.linalg.norm(xm)
        norm_ym = np.linalg.norm(ym)
        return np.dot(xm / norm_xm, ym / norm_ym) * 100.0


class Matthews:
    """Matthews"""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'Matthews'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.argmax(logits, axis=1)
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(np.bool_)
        self.logits_array = np.concatenate([self.logits_array, logits]).astype(np.bool_)

    def get_metrics(self):
        if len(self.labels_array) < 2:
            return 0.0
        tp = np.sum(self.labels_array & self.logits_array)
        # same relabeling as in F1; MCC is symmetric in fp/fn, so the value is unchanged
        fn = np.sum(self.labels_array & (~self.logits_array))
        fp = np.sum((~self.labels_array) & self.logits_array)
        tn = np.sum((~self.labels_array) & (~self.logits_array))
        return (tp * tn - fp * fn) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) * 100.0
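A quick sanity check for these metric classes: the only MindSpore-specific call they make is asnumpy(), so a tiny stand-in object is enough to exercise them outside a graph (the _FakeTensor helper below is illustrative, not part of the source tree):

# Minimal exercise of Accuracy with a stand-in for MindSpore tensors.
import numpy as np

class _FakeTensor:
    """Stand-in exposing the asnumpy() method the metrics rely on."""
    def __init__(self, arr):
        self._arr = np.asarray(arr)
    def asnumpy(self):
        return self._arr

acc = Accuracy()
logits = _FakeTensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = _FakeTensor([1, 0, 0])
acc.update(logits, labels)
print(acc.name, acc.get_metrics())  # Accuracy 66.66... (2 of 3 argmax predictions match)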
@ -0,0 +1,299 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
|
||||
|
||||
class ClipByNorm(nn.Cell):
|
||||
"""
|
||||
Clips tensor values to a maximum :math:`L_2`-norm.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ClipByNorm, self).__init__()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
||||
self.select_ = P.Select()
|
||||
self.greater_ = P.Greater()
|
||||
self.cast = P.Cast()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.max_op = P.Maximum()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.fill = P.Fill()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self, x, clip_norm):
|
||||
"""add ms_function decorator for pynative mode"""
|
||||
mul_x = F.square(x)
|
||||
if mul_x.shape == (1,):
|
||||
l2sum = self.cast(mul_x, mstype.float32)
|
||||
else:
|
||||
l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
|
||||
cond = self.greater_(l2sum, 0)
|
||||
ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
|
||||
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
|
||||
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
|
||||
|
||||
intermediate = x * clip_norm
|
||||
|
||||
max_norm = self.max_op(l2norm, clip_norm)
|
||||
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
|
||||
values_clip = self.reshape(values_clip, self.shape(x))
|
||||
values_clip = F.identity(values_clip)
|
||||
return values_clip
|
||||
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
"""clip gradients"""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
|
||||


class CrossEntropy(nn.Cell):
    """
    Cross Entropy loss
    """
    def __init__(self, num_labels):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.num_labels = num_labels

    def construct(self, logits, label_ids):
        label_ids = self.reshape(label_ids, self.last_idx)
        one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value, self.off_value)
        per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
        loss = self.reduce_mean(per_example_loss, self.last_idx)
        return_value = self.cast(loss, mstype.float32)
        return return_value
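
# Illustrative NumPy sketch of the loss above. The inputs are assumed to already
# be log-probabilities; if they were raw scores, a log_softmax would come first.
# The loss is then the mean negative log-probability of the gold label.
def _cross_entropy_sketch(log_probs, label_ids, num_labels):
    """log_probs: [batch, num_labels]; label_ids: [batch] of ints (illustration only)."""
    import numpy as np
    one_hot = np.eye(num_labels)[label_ids]
    per_example_loss = -np.sum(one_hot * log_probs, axis=-1)
    return per_example_loss.mean()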


class NetworkWithCLSLoss(nn.Cell):
    def __init__(self, network):
        super(NetworkWithCLSLoss, self).__init__(auto_prefix=False)
        self.cls_network = network
        self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        logits = self.cls_network(input_ids, input_mask, token_type_id)
        cls_loss = self.loss_fct(logits, label_ids)
        return cls_loss


class NetworkWithMLMLoss(nn.Cell):
    def __init__(self, network, vocab_size=21128):
        super(NetworkWithMLMLoss, self).__init__(auto_prefix=False)
        self.mlm_network = network
        self.vocab_size = vocab_size
        self.reshape = P.Reshape()
        self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        prediction_scores = self.mlm_network(input_ids, input_mask, token_type_id)
        prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))
        label_ids = self.reshape(label_ids, (-1,))
        mlm_loss = self.loss_fct(prediction_scores, label_ids)
        return mlm_loss
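
# Illustrative wiring sketch (names are from this commit): the loss cells wrap a
# backbone so that one construct() call yields a scalar loss, which the train
# cells below differentiate.
#
#     cls_network = AlbertModelCLS(client_net_cfg)
#     network_with_loss = NetworkWithCLSLoss(cls_network)
#     loss = network_with_loss(input_ids, input_mask, token_type_id, label_ids)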


class NetworkTrainCell(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(NetworkTrainCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.clip_type = 1
        self.clip_value = 1.0
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

        self.get_weights_by_key = P.PullWeight()
        self.over_weights_by_key = P.PushWeight()

        self.get_weights_by_key_input_1, \
            self.get_weights_by_key_input_2, \
            self.get_weights_by_key_input_3 = self._get_weights_by_key_inputs(self.network.parameters_and_names())

        self.over_weights_by_key_input_1, \
            self.over_weights_by_key_input_2, \
            self.over_weights_by_key_input_3 = self._over_weights_by_key_inputs(self.network.parameters_and_names())

    def _communication_with_server_1(self, weights):
        result = self.hyper_map(F.partial(self.get_weights_by_key), weights,
                                self.get_weights_by_key_input_2, self.get_weights_by_key_input_3)
        return result

    def _communication_with_server_2(self, weights):
        result = self.hyper_map(F.partial(self.over_weights_by_key), weights,
                                self.over_weights_by_key_input_2,
                                self.over_weights_by_key_input_3)
        return result

    def _get_weights_by_key_inputs(self, weights):
        filtered_weights = []
        weight_names = []
        weight_indices = []
        index = 0
        for weight in weights:
            if weight[1].pull_weight_from_server:
                filtered_weights.append(weight[1])
                weight_names.append(weight[1].name)
                weight_indices.append(index)
            index += 1
        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    def _over_weights_by_key_inputs(self, weights):
        filtered_weights = []
        weight_names = []
        weight_indices = []
        index = 0
        for weight in weights:
            if weight[1].push_weight_to_server:
                filtered_weights.append(weight[1])
                weight_names.append(weight[1].name)
                weight_indices.append(index)
            index += 1
        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        weights = self.weights
        res = self._communication_with_server_1(self.get_weights_by_key_input_1)
        input_ids = F.depend(input_ids, res)
        loss = self.network(input_ids, input_mask, token_type_id, label_ids)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
        loss = F.depend(loss, self.optimizer(grads))
        weights1 = F.depend(self.over_weights_by_key_input_1, loss)
        loss = F.depend(loss, self._communication_with_server_2(weights1))
        return loss
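
# Illustrative sketch of one federated step as ordered above by F.depend:
# (1) PullWeight fetches the listed parameters from the server, (2) the local
# forward/backward pass and optimizer update run, (3) PushWeight uploads the
# updated parameters. In plain pseudocode:
#
#     pulled = pull_weight(names=self.get_weights_by_key_input_2)
#     loss = network(batch)              # depends on the pulled weights
#     optimizer(clip(grads(batch)))      # local update
#     push_weight(names=self.over_weights_by_key_input_2)  # after the update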


class NetworkNoClientTrainCell(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(NetworkNoClientTrainCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.clip_type = 1
        self.clip_value = 1.0
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        weights = self.weights
        loss = self.network(input_ids, input_mask, token_type_id, label_ids)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
        succ = self.optimizer(grads)
        return F.depend(loss, succ)

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""config script"""

from easydict import EasyDict as edict
from mindspore.common import dtype as mstype
from src.model import AlbertConfig


gradient_cfg = edict({
    'clip_type': 1,
    'clip_value': 1.0
})


train_cfg = edict({
    'batch_size': 16,
    'loss_scale_value': 2 ** 16,
    'scale_factor': 2,
    'scale_window': 50,
    'max_global_epoch': 10,  # fl_iteration_num
    'server_cfg': edict({
        'learning_rate': 1e-5,
        'max_local_epoch': 1,
        'cyclic_trunc': False
    }),
    'client_cfg': edict({
        'learning_rate': 1e-5,
        'max_local_epoch': 1,
        'num_per_epoch': 20,
        'cyclic_trunc': True
    }),
    'optimizer_cfg': edict({
        'AdamWeightDecay': edict({
            'end_learning_rate': 1e-14,
            'power': 1.0,
            'weight_decay': 1e-4,
            'eps': 1e-6,
            'decay_filter': lambda x: 'norm' not in x.name.lower() and 'bias' not in x.name.lower(),
            'warmup_ratio': 0.1
        }),
    }),
})
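
# Illustrative sketch of how settings like the AdamWeightDecay block above are
# commonly turned into a schedule: linear warmup over the first warmup_ratio of
# steps, then polynomial decay with exponent `power` down to end_learning_rate.
# This helper is an assumption for illustration; the actual schedule is built
# wherever the optimizer is constructed, not in this config file.
def _polynomial_lr_sketch(step, total_steps, lr=1e-5, end_lr=1e-14, power=1.0, warmup_ratio=0.1):
    warmup_steps = int(total_steps * warmup_ratio)
    if warmup_steps > 0 and step < warmup_steps:
        return lr * step / warmup_steps
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return (lr - end_lr) * (1 - progress) ** power + end_lr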

eval_cfg = edict({
    'batch_size': 256,
})

server_net_cfg = AlbertConfig(
    seq_length=8,
    vocab_size=11682,
    hidden_size=312,
    num_hidden_groups=1,
    num_hidden_layers=4,
    inner_group_num=1,
    num_attention_heads=12,
    intermediate_size=1248,
    hidden_act="gelu",
    query_act=None,
    key_act=None,
    value_act=None,
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    classifier_dropout_prob=0.0,
    embedding_size=128,
    layer_norm_eps=1e-12,
    has_attention_mask=True,
    do_return_2d_tensor=True,
    use_one_hot_embeddings=False,
    use_token_type=True,
    return_all_encoders=False,
    output_attentions=False,
    output_hidden_states=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    is_training=True,
    num_labels=4,
    use_word_embeddings=True
)

client_net_cfg = AlbertConfig(
    seq_length=8,
    vocab_size=11682,
    hidden_size=312,
    num_hidden_groups=1,
    num_hidden_layers=4,
    inner_group_num=1,
    num_attention_heads=12,
    intermediate_size=1248,
    hidden_act="gelu",
    query_act=None,
    key_act=None,
    value_act=None,
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    classifier_dropout_prob=0.0,
    embedding_size=128,
    layer_norm_eps=1e-12,
    has_attention_mask=True,
    do_return_2d_tensor=True,
    use_one_hot_embeddings=False,
    use_token_type=True,
    return_all_encoders=False,
    output_attentions=False,
    output_hidden_states=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    is_training=True,
    num_labels=4,
    use_word_embeddings=True
)

@ -0,0 +1,187 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import pickle
import numpy as np
from mindspore import dataset as ds
from mindspore.dataset.transforms import c_transforms as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype


class InputFeatures:
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id, seq_length=None):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.seq_length = seq_length


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, cyclic_trunc=False):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for example in examples:
        tokens = tokenizer.tokenize(example[0])
        seq_length = len(tokens)
        if seq_length > max_seq_length - 2:
            if cyclic_trunc:
                rand_index = np.random.randint(0, seq_length)
                tokens = [tokens[idx] if idx < seq_length else tokens[idx - seq_length]
                          for idx in range(rand_index, rand_index + max_seq_length - 2)]
            else:
                tokens = tokens[: (max_seq_length - 2)]

        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        seq_length = len(input_ids)

        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        label_id = label_map[example[1]]

        features.append(InputFeatures(input_ids=input_ids,
                                      input_mask=input_mask,
                                      segment_ids=segment_ids,
                                      label_id=label_id,
                                      seq_length=seq_length))
    return features
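
# Illustrative sketch of the cyclic truncation above: when a sequence is too
# long, a random start index is drawn and max_seq_length - 2 tokens are taken
# circularly (wrapping around the end), instead of always keeping the prefix.
def _cyclic_trunc_sketch(tokens, max_seq_length, rand_index):
    """E.g. tokens=list('abcdef'), max_seq_length=6, rand_index=4 -> ['e', 'f', 'a', 'b']."""
    n = len(tokens)
    return [tokens[i % n] for i in range(rand_index, rand_index + max_seq_length - 2)]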


def load_dataset(data_path, max_seq_length, tokenizer, batch_size, label_list=None, do_shuffle=True,
                 drop_remainder=True, output_dir=None, i=0, cyclic_trunc=False):
    if label_list is None:
        label_list = ['good', 'leimu', 'xiaoku', 'xin']
    with open(data_path, 'r', encoding='utf-8') as f:
        data = f.read()
    data_list = data.split('\n<<<')
    input_list = []
    for key in data_list[1:]:
        key = key.split('>>>')
        input_list.append([key[1], key[0]])
    datasets = create_ms_dataset(input_list, label_list, max_seq_length, tokenizer, batch_size,
                                 do_shuffle=do_shuffle, drop_remainder=drop_remainder, cyclic_trunc=cyclic_trunc)
    if output_dir is not None:
        output_path = os.path.join(output_dir, str(i) + '.dat')
        print(output_path)
        with open(output_path, "wb") as f:
            pickle.dump(tuple(datasets), f)
    del data, data_list, input_list
    return datasets, len(label_list)


def load_datasets(data_dir, max_seq_length, tokenizer, batch_size, label_list=None, do_shuffle=True,
                  drop_remainder=True, output_dir=None, cyclic_trunc=False):
    if label_list is None:
        label_list = ['good', 'leimu', 'xiaoku', 'xin']
    data_path_list = os.listdir(data_dir)
    datasets_list = []
    for i, relative_path in enumerate(data_path_list):
        data_path = os.path.join(data_dir, relative_path)
        with open(data_path, 'r', encoding='utf-8') as f:
            data = f.read()
        data_list = data.split('\n<<<')
        input_list = []
        for key in data_list[1:]:
            key = key.split('>>>')
            input_list.append([key[1], key[0]])
        datasets = create_ms_dataset(input_list, label_list, max_seq_length, tokenizer, batch_size,
                                     do_shuffle=do_shuffle, drop_remainder=drop_remainder, cyclic_trunc=cyclic_trunc)
        if output_dir is not None:
            output_path = os.path.join(output_dir, str(i) + '.dat')
            print(output_path)
            with open(output_path, "wb") as f:
                pickle.dump(tuple(datasets.create_tuple_iterator()), f)
        datasets_list.append(datasets)
    return datasets_list, len(label_list)


def create_ms_dataset(data_list, label_list, max_seq_length, tokenizer, batch_size, do_shuffle=True,
                      drop_remainder=True, cyclic_trunc=False):
    features = convert_examples_to_features(data_list, label_list, max_seq_length, tokenizer,
                                            cyclic_trunc=cyclic_trunc)

    def generator_func():
        for feature in features:
            yield (np.array(feature.input_ids),
                   np.array(feature.input_mask),
                   np.array(feature.segment_ids),
                   np.array(feature.label_id),
                   np.array(feature.seq_length))

    dataset = ds.GeneratorDataset(generator_func,
                                  ['input_ids', 'input_mask', 'token_type_id', 'label_ids', 'seq_length'])
    if do_shuffle:
        dataset = dataset.shuffle(buffer_size=10000)

    type_cast_op = C.TypeCast(mstype.int32)
    dataset = dataset.map(operations=[type_cast_op])
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
    return dataset
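
# Illustrative usage sketch (directory and sizes are assumptions, not fixed by
# this file): each file under data_dir yields one dataset, so in the federated
# setup one file per client can supply that client's local dataset.
#
#     datasets_list, num_labels = load_datasets(
#         '../datasets/train/', max_seq_length=8, tokenizer=tokenizer,
#         batch_size=16, do_shuffle=True, drop_remainder=True)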


class ConstructMaskAndReplaceTensor:
    def __init__(self, batch_size, max_seq_length, vocab_size, keep_first_unchange=True, keep_last_unchange=True):
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size
        self.keep_first_unchange = keep_first_unchange
        self.keep_last_unchange = keep_last_unchange
        self.mask_tensor = np.ones((self.batch_size, self.max_seq_length))
        self.replace_tensor = np.zeros((self.batch_size, self.max_seq_length))

    def construct(self, seq_lengths):
        for i in range(self.batch_size):
            for j in range(seq_lengths[i]):
                rand1 = np.random.random()
                if rand1 < 0.15:  # select this token for MLM
                    self.mask_tensor[i, j] = 0
                    rand2 = np.random.random()
                    if rand2 < 0.8:  # 80%: replace with mask token id 103
                        self.replace_tensor[i, j] = 103
                    elif rand2 < 0.9:  # 10%: keep the token unchanged
                        self.mask_tensor[i, j] = 1
                    else:  # 10%: replace with a random vocabulary id
                        self.replace_tensor[i, j] = np.random.randint(0, self.vocab_size)
                else:
                    self.mask_tensor[i, j] = 1
                    self.replace_tensor[i, j] = 0
            for j in range(seq_lengths[i], self.max_seq_length):
                self.mask_tensor[i, j] = 1
                self.replace_tensor[i, j] = 0
            if self.keep_first_unchange:
                self.mask_tensor[i, 0] = 1
                self.replace_tensor[i, 0] = 0
            if self.keep_last_unchange:
                self.mask_tensor[i, seq_lengths[i] - 1] = 1
                self.replace_tensor[i, seq_lengths[i] - 1] = 0
        mask_tensor = Tensor(self.mask_tensor, dtype=mstype.int32)
        replace_tensor = Tensor(self.replace_tensor, dtype=mstype.int32)
        return mask_tensor, replace_tensor
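
# Illustrative sketch of the masking policy above, in expectation: each real
# token is selected for MLM with probability 0.15; of the selected tokens, 80%
# become token id 103 (assumed here to be [MASK]), 10% are kept, and 10% become
# a random vocabulary id. A plausible way to combine the outputs (an assumption,
# not code from this commit):
#
#     masker = ConstructMaskAndReplaceTensor(batch_size=16, max_seq_length=8,
#                                            vocab_size=11682)
#     mask_tensor, replace_tensor = masker.construct(seq_lengths)
#     # masked_ids = input_ids * mask_tensor + replace_tensor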
@ -0,0 +1,892 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import math
import copy
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter


class AlbertConfig:
    """
    Configuration for `AlbertModel`.

    Args:
        seq_length (int): Length of input sequence. Default: 256.
        vocab_size (int): Size of the vocabulary. Default: 21128.
        hidden_size (int): Size of the albert encoder layers. Default: 312.
        num_hidden_layers (int): Number of hidden layers in the AlbertTransformer encoder
                           cell. Default: 4.
        num_attention_heads (int): Number of attention heads in the AlbertTransformer
                             encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the AlbertTransformer
                           encoder cell. Default: 1248.
        hidden_act (str): Activation function used in the AlbertTransformer encoder
                    cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for AlbertOutput. Default: 0.0.
        attention_probs_dropout_prob (float): The dropout probability for
                                      AlbertAttention. Default: 0.0.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 2.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in AlbertTransformer. Default: mstype.float32.
    """

    def __init__(self,
                 seq_length=256,
                 vocab_size=21128,
                 hidden_size=312,
                 num_hidden_groups=1,
                 num_hidden_layers=4,
                 inner_group_num=1,
                 num_attention_heads=12,
                 intermediate_size=1248,
                 hidden_act="gelu",
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 hidden_dropout_prob=0.0,
                 attention_probs_dropout_prob=0.0,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 classifier_dropout_prob=0.1,
                 embedding_size=128,
                 layer_norm_eps=1e-12,
                 has_attention_mask=True,
                 do_return_2d_tensor=True,
                 use_one_hot_embeddings=False,
                 use_token_type=True,
                 return_all_encoders=False,
                 output_attentions=False,
                 output_hidden_states=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 is_training=True,
                 num_labels=5,
                 use_word_embeddings=True):
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.inner_group_num = inner_group_num
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.query_act = query_act
        self.key_act = key_act
        self.value_act = value_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.use_relative_positions = use_relative_positions
        self.classifier_dropout_prob = classifier_dropout_prob
        self.embedding_size = embedding_size
        self.layer_norm_eps = layer_norm_eps
        self.num_hidden_groups = num_hidden_groups
        self.has_attention_mask = has_attention_mask
        self.do_return_2d_tensor = do_return_2d_tensor
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.use_token_type = use_token_type
        self.return_all_encoders = return_all_encoders
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.dtype = dtype
        self.compute_type = compute_type
        self.is_training = is_training
        self.num_labels = num_labels
        self.use_word_embeddings = use_word_embeddings


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = config.vocab_size
        self.use_one_hot_embeddings = config.use_one_hot_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(config.initializer_range),
                                          [config.vocab_size, config.embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids):
        """embedding lookup"""
        flat_ids = self.reshape(input_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessors apply positional and token type embeddings to word embeddings.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = config.use_token_type
        self.token_type_vocab_size = config.type_vocab_size
        self.use_one_hot_embeddings = config.use_one_hot_embeddings
        self.max_position_embeddings = config.max_position_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(config.initializer_range),
                                          [config.type_vocab_size,
                                           config.embedding_size]))
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)  # off value for the one-hot token-type lookup
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = (-1, config.seq_length, config.embedding_size)
        self.layernorm = nn.LayerNorm((config.embedding_size,))
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = config.use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer
                                                  (TruncatedNormal(config.initializer_range),
                                                   [config.max_position_embeddings,
                                                    config.embedding_size]))

    def construct(self, token_type_ids, word_embeddings):
        """embedding postprocessor"""
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


class AlbertOutput(nn.Cell):
    """
    Apply a linear computation to hidden states and a residual connection to the input.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertOutput, self).__init__()
        self.dense = nn.Dense(config.hidden_size, config.hidden_size,
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.add = P.Add()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(mstype.float32)
            self.compute_type = config.compute_type
        else:
            self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)

        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        """bert output"""
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            output = self.cast(output, self.compute_type)
        return output


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """

    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        """position matrix generator"""
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)
        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)
        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat
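
# Illustrative NumPy sketch of the matrix built above: with k = max_relative_position,
# entry (i, j) is clip(j - i, -k, k) + k, a non-negative bucket id per position pair.
def _relative_position_matrix_sketch(length, max_relative_position):
    """NumPy analogue of RelaPosMatrixGenerator.construct, for illustration only."""
    import numpy as np
    idx = np.arange(length)
    distance = idx[None, :] - idx[:, None]
    k = max_relative_position
    return np.clip(distance, -k, k) + k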


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """

    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')
        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.Gather()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        """position embedding generation"""
        relative_positions_matrix_out = self.relative_positions_matrix()
        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
    the danger that the value will overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """

    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max
        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        """saturate cast"""
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)
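
# Illustrative NumPy sketch of the saturating cast above: values are clamped to
# the destination type's representable range before the cast, so e.g. 1e20 in
# float32 becomes float16's max instead of inf.
def _saturate_cast_sketch(x, dst_dtype):
    """Clamp to the finite range of dst_dtype, then cast (illustration only)."""
    import numpy as np
    info = np.finfo(dst_dtype)
    return np.clip(x, info.min, info.max).astype(dst_dtype)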


class AlbertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertAttention, self).__init__()
        self.from_seq_length = config.seq_length
        self.to_seq_length = config.seq_length
        self.num_attention_heads = config.num_attention_heads
        self.size_per_head = int(config.hidden_size / config.num_attention_heads)
        self.has_attention_mask = config.has_attention_mask
        self.use_relative_positions = config.use_relative_positions
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=config.compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, config.hidden_size)
        self.shape_to_2d = (-1, config.hidden_size)
        weight = TruncatedNormal(config.initializer_range)

        self.query = nn.Dense(config.hidden_size,
                              config.hidden_size,
                              activation=config.query_act,
                              weight_init=weight).to_float(config.compute_type)
        self.key = nn.Dense(config.hidden_size,
                            config.hidden_size,
                            activation=config.key_act,
                            weight_init=weight).to_float(config.compute_type)
        self.value = nn.Dense(config.hidden_size,
                              config.hidden_size,
                              activation=config.value_act,
                              weight_init=weight).to_float(config.compute_type)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.matmul = P.BatchMatMul()
        self.shape_from = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
        self.shape_to = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0], dtype=config.compute_type)
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - config.attention_probs_dropout_prob)
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if config.do_return_2d_tensor:
            self.shape_return = (-1, config.hidden_size)
        else:
            self.shape_return = (-1, config.seq_length, config.hidden_size)
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=config.seq_length,
                                           depth=self.size_per_head,
                                           max_relative_position=16,
                                           initializer_range=config.initializer_range,
                                           use_one_hot_embeddings=config.use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        """bert attention"""
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query(from_tensor_2d)
        key_out = self.key(to_tensor_2d)
        value_out = self.value(to_tensor_2d)
        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)
        attention_scores = self.matmul_trans_b(query_layer, key_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          -1,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  -1,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t
        attention_scores = self.multiply(self.scores_mul, attention_scores)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))
            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 -1,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    -1,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t
        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)
        return context_layer, attention_scores
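
# Illustrative NumPy sketch of the core computation above (ignoring relative
# positions): scaled dot-product attention with the additive -10000.0 mask trick,
# where mask entries of 0 mark padding positions.
def _scaled_dot_product_attention_sketch(q, k, v, mask):
    """q, k, v: [batch, heads, seq, head_dim]; mask: [batch, 1, 1, seq] of 0/1 (illustration only)."""
    import numpy as np
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    scores = scores + (1.0 - mask) * -10000.0
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v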


class AlbertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.attention = AlbertAttention(config)
        self.output = AlbertOutput(config)
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert self attention"""
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output, attention_scores


class AlbertEncoderCell(nn.Cell):
    """
    Encoder cells used in AlbertTransformer.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertEncoderCell, self).__init__()
        self.attention = AlbertSelfAttention(config)
        self.intermediate = nn.Dense(in_channels=config.hidden_size,
                                     out_channels=config.intermediate_size,
                                     activation=config.hidden_act,
                                     weight_init=TruncatedNormal(config.initializer_range)
                                     ).to_float(config.compute_type)
        self.output = AlbertOutput(config)

    def construct(self, hidden_states, attention_mask):
        """bert encoder cell"""
        # self-attention
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output, attention_scores


class AlbertLayer(nn.Cell):
    """
    Args:
        config (AlbertConfig): Albert Config.
    """
    def __init__(self, config):
        super(AlbertLayer, self).__init__()

        self.output_attentions = config.output_attentions
        self.attention = AlbertSelfAttention(config)
        self.ffn = nn.Dense(config.hidden_size,
                            config.intermediate_size,
                            activation=config.hidden_act).to_float(config.compute_type)
        self.ffn_output = nn.Dense(config.intermediate_size, config.hidden_size)
        self.full_layer_layer_norm = nn.LayerNorm((config.hidden_size,))
        self.shape = (-1, config.seq_length, config.hidden_size)
        self.reshape = P.Reshape()

    def construct(self, hidden_states, attention_mask):
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)

        ffn_output = self.ffn(attention_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.reshape(ffn_output + attention_output, self.shape)
        hidden_states = self.full_layer_layer_norm(ffn_output)

        return hidden_states, attention_scores


class AlbertLayerGroup(nn.Cell):
    """
    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertLayerGroup, self).__init__()

        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.albert_layers = nn.CellList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def construct(self, hidden_states, attention_mask):
        layer_hidden_states = ()
        layer_attentions = ()

        for _, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask)
            hidden_states = layer_output[0]
            if self.output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)
            if self.output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_attentions:
            outputs = outputs + (layer_attentions,)
        if self.output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        return outputs


class AlbertTransformer(nn.Cell):
    """
    Multi-layer albert transformer.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertTransformer, self).__init__()
        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        self.group_idx_list = [int(_ / (config.num_hidden_layers / config.num_hidden_groups))
                               for _ in range(config.num_hidden_layers)]

        self.embedding_hidden_mapping_in = nn.Dense(config.embedding_size, config.hidden_size)
        self.return_all_encoders = config.return_all_encoders
        layers = []
        for _ in range(config.num_hidden_groups):
            layer = AlbertLayerGroup(config)
            layers.append(layer)
        self.albert_layer_groups = nn.CellList(layers)
        self.reshape = P.Reshape()
        self.shape = (-1, config.embedding_size)
        self.out_shape = (-1, config.seq_length, config.hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert transformer"""
        prev_output = self.reshape(input_tensor, self.shape)
        prev_output = self.embedding_hidden_mapping_in(prev_output)
        all_encoder_layers = ()
        all_encoder_atts = ()
        all_encoder_outputs = (prev_output,)
        # for layer_module in self.layers:
        for i in range(self.num_hidden_layers):
            # Index of the hidden group
            group_idx = self.group_idx_list[i]

            layer_output, encoder_att = self.albert_layer_groups[group_idx](prev_output, attention_mask)
            prev_output = layer_output
            if self.return_all_encoders:
                all_encoder_outputs += (layer_output,)
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers += (layer_output,)
                all_encoder_atts += (encoder_att,)
        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers += (prev_output,)
        return prev_output
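
# Illustrative sketch of ALBERT's cross-layer parameter sharing as wired above:
# num_hidden_layers forward passes are distributed over num_hidden_groups
# parameter groups, so layer i reuses the weights of group
# int(i / (num_hidden_layers / num_hidden_groups)).
def _group_index_sketch(num_hidden_layers, num_hidden_groups):
    """E.g. 4 layers, 1 group -> [0, 0, 0, 0]; 4 layers, 2 groups -> [0, 0, 1, 1]."""
    return [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(num_hidden_layers)]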


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for AlbertModel.
    """

    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (-1, 1, config.seq_length)

    def construct(self, input_mask):
        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        return attention_mask


class AlbertModel(nn.Cell):
    """
    ALBERT model: a lite variant of Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for AlbertModel.
    """

    def __init__(self, config):
        super(AlbertModel, self).__init__()
        config = copy.deepcopy(config)
        if not config.is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        self.use_word_embeddings = config.use_word_embeddings
        if self.use_word_embeddings:
            self.word_embeddings = EmbeddingLookup(config)
        self.embedding_postprocessor = EmbeddingPostprocessor(config)
        self.encoder = AlbertTransformer(config)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.pooler = nn.Dense(self.hidden_size, self.hidden_size,
                               activation="tanh",
                               weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """bert model"""
        # embedding
        if self.use_word_embeddings:
            word_embeddings, _ = self.word_embeddings(input_ids)
        else:
            word_embeddings = input_ids
        embedding_output = self.embedding_postprocessor(token_type_ids, word_embeddings)
        # attention mask [batch_size, 1, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # bert encoder
        encoder_output = self.encoder(self.cast_compute_type(embedding_output), attention_mask)
        sequence_output = self.cast(encoder_output, self.dtype)
        # pooler
        batch_size = P.Shape()(input_ids)[0]
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.pooler(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        return sequence_output, pooled_output


class AlbertMLMHead(nn.Cell):
    """
    Get masked lm output.

    Args:
        config (AlbertConfig): The config of AlbertModel.

    Returns:
        Tensor, masked lm output.
    """
    def __init__(self, config):
        super(AlbertMLMHead, self).__init__()

        self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_type)
        self.dense = nn.Dense(
            config.hidden_size,
            config.embedding_size,
            weight_init=TruncatedNormal(config.initializer_range),
            activation=config.hidden_act
        ).to_float(config.compute_type)
        self.decoder = nn.Dense(
            config.embedding_size,
            config.vocab_size,
            weight_init=TruncatedNormal(config.initializer_range),
        ).to_float(config.compute_type)

    def construct(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class AlbertModelCLS(nn.Cell):
    """
    This class is responsible for classification task evaluation,
    i.e. mnli(num_labels=3), qnli(num_labels=2), qqp(num_labels=2).
    The returned output represents the final logits; since log_softmax is monotonically
    related to softmax, the logits can be compared directly for prediction.
    """

    def __init__(self, config):
        super(AlbertModelCLS, self).__init__()
        self.albert = AlbertModel(config)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.classifier = nn.Dense(config.hidden_size, config.num_labels, weight_init=self.weight_init,
                                   has_bias=True).to_float(config.compute_type)
        self.relu = nn.ReLU()
        self.is_training = config.is_training
        if self.is_training:
            self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)

    def construct(self, input_ids, input_mask, token_type_id):
        """classification albert model"""
        _, pooled_output = self.albert(input_ids, token_type_id, input_mask)
        # pooled_output = self.relu(pooled_output)
        if self.is_training:
            pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.cast(logits, self.dtype)
        return logits


class AlbertModelForAD(nn.Cell):
    """albert model for ad"""

    def __init__(self, config):
        super(AlbertModelForAD, self).__init__()

        # main model
        self.albert = AlbertModel(config)

        # classifier head
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.classifier = nn.Dense(config.hidden_size, config.num_labels,
                                   weight_init=TruncatedNormal(config.initializer_range),
                                   has_bias=True).to_float(config.compute_type)
        self.is_training = config.is_training
        if self.is_training:
            self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)

        # masked language model head
        self.predictions = AlbertMLMHead(config)

    def construct(self, input_ids, input_mask, token_type_id):
        """albert model for ad"""
        sequence_output, pooled_output = self.albert(input_ids, token_type_id, input_mask)
        if self.is_training:
            pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.cast(logits, self.dtype)
        prediction_scores = self.predictions(sequence_output)
        prediction_scores = self.cast(prediction_scores, self.dtype)
        return prediction_scores, logits


class AlbertModelMLM(nn.Cell):
    """albert model for mlm"""

    def __init__(self, config):
        super(AlbertModelMLM, self).__init__()
        self.cast = P.Cast()
        self.dtype = config.dtype

        # main model
        self.albert = AlbertModel(config)

        # masked language model head
        self.predictions = AlbertMLMHead(config)

    def construct(self, input_ids, input_mask, token_type_id):
        """albert model for mlm"""
        sequence_output, _ = self.albert(input_ids, token_type_id, input_mask)
        prediction_scores = self.predictions(sequence_output)
        prediction_scores = self.cast(prediction_scores, self.dtype)
        return prediction_scores

@ -0,0 +1,558 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import unicodedata
|
||||
from io import open
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_vocab(vocab_file, vocab_map_ids_path=None):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
if vocab_map_ids_path is not None:
|
||||
vocab_new_ids = list()
|
||||
with open(vocab_map_ids_path, "r", encoding="utf-8") as vocab_new_ids_reader:
|
||||
while True:
|
||||
index = vocab_new_ids_reader.readline()
|
||||
if not index:
|
||||
break
|
||||
index = index.strip()
|
||||
vocab_new_ids.append(int(index))
|
||||
index = 0
|
||||
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||
while True:
|
||||
token = reader.readline()
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = vocab_new_ids[index]
|
||||
index += 1
|
||||
return vocab
|
||||
index = 0
|
||||
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||
while True:
|
||||
token = reader.readline()
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
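

# Hedged usage sketch (not part of the original change): builds a tiny vocab
# plus an id-remapping file in a temporary directory to show both load_vocab
# modes. The file contents below are made up for illustration.
def _demo_load_vocab():
    import tempfile
    tmp_dir = tempfile.mkdtemp()
    vocab_file = os.path.join(tmp_dir, 'vocab.txt')
    map_file = os.path.join(tmp_dir, 'vocab_map_ids.txt')
    with open(vocab_file, 'w', encoding='utf-8') as f:
        f.write('[PAD]\n[UNK]\nhello\n')
    with open(map_file, 'w', encoding='utf-8') as f:
        f.write('0\n100\n200\n')
    # Without a map file, ids are line numbers:
    # OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('hello', 2)])
    print(load_vocab(vocab_file))
    # With a map file, the i-th token receives the i-th remapped id:
    # OrderedDict([('[PAD]', 0), ('[UNK]', 100), ('hello', 200)])
    print(load_vocab(vocab_file, vocab_map_ids_path=map_file))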
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer:
    """Runs end-to-end tokenization: punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, basic_only=False,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file.
            do_lower_case: Whether to lower case the input.
                Only has an effect when do_basic_tokenize=True.
            do_basic_tokenize: Whether to do basic tokenization before wordpiece.
            basic_only: Whether to stop after basic tokenization (skip wordpiece).
            max_len: An artificial maximum length to truncate tokenized sequences to;
                the effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
                Only has an effect when do_basic_tokenize=True.
        """
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
        self.basic_only = basic_only

    def tokenize(self, text):
        """Splits text into basic tokens and then, unless basic_only, word pieces."""
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text):
                if self.basic_only:
                    split_tokens.append(token)
                else:
                    for sub_token in self.wordpiece_tokenizer.tokenize(token):
                        split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.vocab.get(token, self.vocab['[UNK]']))
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids into wordpiece tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def save_vocabulary(self, vocab_path):
        """Saves the tokenizer vocabulary to a directory as vocab.txt."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, 'vocab.txt')
        else:
            raise FileNotFoundError("vocab_path should be an existing directory: {}".format(vocab_path))
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # Non-consecutive ids (e.g. a remapped vocab); keep writing
                    # tokens in id order anyway.
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a BertTokenizer from a vocabulary file or from a
        directory containing vocab.txt.
        """
        if pretrained_model_name_or_path.endswith('.txt'):
            resolved_vocab_file = pretrained_model_name_or_path
        else:
            resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'vocab.txt')

        max_len = 512
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer
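

# Hedged usage sketch (not part of the original change): assumes vocab_dir
# holds a standard one-token-per-line vocab.txt; the default path mirrors the
# eval script's tokenizer_dir but is only illustrative.
def _demo_bert_tokenizer(vocab_dir='../model_save/init/'):
    tokenizer = BertTokenizer.from_pretrained(vocab_dir, do_lower_case=True)
    tokens = tokenizer.tokenize('Hello, unaffable world!')
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Round-trips back to the same pieces, assuming vocabulary ids are unique;
    # out-of-vocabulary pieces were already mapped to [UNK] by tokenize().
    assert tokenizer.convert_ids_to_tokens(ids) == tokens
    return tokens, ids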
class BasicTokenizer:
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self,
                 do_lower_case=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
            never_split: Tokens that are passed through unchanged.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    @staticmethod
    def _run_strip_accents(text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Punctuation always forms its own token.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    @staticmethod
    def _is_chinese_char(cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
                (0x4E00 <= cp <= 0x9FFF) or
                (0x3400 <= cp <= 0x4DBF) or
                (0x20000 <= cp <= 0x2A6DF) or
                (0x2A700 <= cp <= 0x2B73F) or
                (0x2B740 <= cp <= 0x2B81F) or
                (0x2B820 <= cp <= 0x2CEAF) or
                (0xF900 <= cp <= 0xFAFF) or
                (0x2F800 <= cp <= 0x2FA1F)
        ):
            return True

        return False

    @staticmethod
    def _clean_text(text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                # Find the longest vocabulary entry matching chars[start:].
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    # No piece matched: the whole word maps to the unknown token.
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
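

# Hedged illustration (not part of the original change) of the greedy
# longest-match-first loop above, run against a toy vocabulary.
def _demo_wordpiece():
    toy_vocab = {'un': 0, '##aff': 1, '##able': 2, '[UNK]': 3}
    wp = WordpieceTokenizer(vocab=toy_vocab)
    # Longest match first: 'unaffable' -> ['un', '##aff', '##able']
    print(wp.tokenize('unaffable'))
    # No prefix of 'xyz' is in the vocabulary, so the whole word becomes [UNK].
    print(wp.tokenize('xyz'))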
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char in ("\t", "\n", "\r"):
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyway, for
    # consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
class CustomizedBasicTokenizer(BasicTokenizer):
    """BasicTokenizer variant that can merge whitespace-split keywords back together."""

    def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
        """Constructs a CustomizedBasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
            keywords: Optional collection of strings; adjacent tokens whose
                concatenation equals a keyword are emitted as a single token.
        """
        super().__init__(do_lower_case, never_split)
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.keywords = keywords

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)

        if self.keywords is not None:
            # Greedily merge runs of adjacent tokens whose concatenation is a
            # keyword, preferring longer runs; see the sketch after this class.
            new_orig_tokens = []
            lengths = [len(_) for _ in self.keywords]
            max_length = max(lengths)
            orig_tokens_len = len(orig_tokens)
            i = 0
            while i < orig_tokens_len:
                has_add = False
                for length in range(max_length, 0, -1):
                    if i + length > orig_tokens_len:
                        continue
                    add_token = ''.join(orig_tokens[i:i+length])
                    if add_token in self.keywords:
                        new_orig_tokens.append(add_token)
                        i += length
                        has_add = True
                        break
                if not has_add:
                    new_orig_tokens.append(orig_tokens[i])
                    i += 1
        else:
            new_orig_tokens = orig_tokens

        split_tokens = []
        for token in new_orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens
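

# Hedged illustration (not part of the original change): the keywords branch
# above re-joins adjacent whitespace-split tokens whenever their concatenation
# appears in `keywords`. The keyword below is made up.
def _demo_keyword_merge():
    tok = CustomizedBasicTokenizer(do_lower_case=True, keywords=['machinelearning'])
    # 'machine' + 'learning' concatenates to the keyword, so the pair is
    # emitted as one token: ['machinelearning', 'rocks']
    print(tok.tokenize('machine learning rocks'))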
class CustomizedTokenizer(BertTokenizer):
    """Runs end-to-end tokenization: keyword-aware punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
        """Constructs a CustomizedTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file.
            do_lower_case: Whether to lower case the input.
            max_len: An artificial maximum length to truncate tokenized sequences to;
                the effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
            keywords: Passed through to CustomizedBasicTokenizer.
        """
        # never_split is passed by keyword: positionally it would land in the
        # base class's do_basic_tokenize slot.
        super().__init__(vocab_file, do_lower_case, max_len, never_split=never_split)

        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = CustomizedBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split,
                                                        keywords=keywords)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        """Runs the keyword-aware basic tokenizer, then wordpiece."""
        split_tokens = []
        basic_tokens = self.basic_tokenizer.tokenize(text)
        for token in basic_tokens:
            wordpiece_tokens = self.wordpiece_tokenizer.tokenize(token)
            for sub_token in wordpiece_tokens:
                split_tokens.append(sub_token)
        return split_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a CustomizedTokenizer from a directory containing
        customized_vocab.txt.
        """
        resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'customized_vocab.txt')

        max_len = 512
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer
class CustomizedTextBasicTokenizer(BasicTokenizer):
    """BasicTokenizer variant that keeps accents and isolates every multi-byte character."""

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
            # Unlike BasicTokenizer.tokenize, accents are deliberately kept here.
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character and any other multi-byte character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp) or len(char.encode('utf-8')) > 1:
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class CustomizedTextTokenizer(BertTokenizer):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"),
                 vocab_map_ids_path=None):
        """Constructs a CustomizedTextTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file.
            do_lower_case: Whether to lower case the input.
            max_len: An artificial maximum length to truncate tokenized sequences to;
                the effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
            vocab_map_ids_path: Optional file remapping vocabulary ids; see load_vocab.
        """
        # never_split is passed by keyword: positionally it would land in the
        # base class's do_basic_tokenize slot.
        super().__init__(vocab_file, do_lower_case, max_len, never_split=never_split)

        self.vocab = load_vocab(vocab_file, vocab_map_ids_path)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = CustomizedTextBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import copy
from mindspore.common.initializer import initializer


def average_weights(para_list):
    """Averages a list of parameter dictionaries element-wise (federated averaging)."""
    global_parameter = {}
    length = len(para_list)
    for para in para_list:
        for name in para:
            if name in global_parameter:
                global_parameter[name] += para[name] / length
            else:
                global_parameter[name] = para[name] / length
    return global_parameter
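

# Hedged sketch (not part of the original change): average_weights is the
# aggregation step of federated averaging; real callers pass parameter dicts
# collected from workers, but plain floats show the arithmetic.
def _demo_average_weights():
    worker_a = {'w': 2.0, 'b': 0.0}
    worker_b = {'w': 4.0, 'b': 1.0}
    # Each entry is the mean across workers: {'w': 3.0, 'b': 0.5}
    print(average_weights([worker_a, worker_b]))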
def save_params(network, param_dict=None):
    """Snapshots trainable parameters.

    With param_dict=None, returns a fresh dict of deep-copied parameters
    (learning-rate and Adam optimizer state excluded); otherwise updates the
    matching entries of param_dict in place and returns None.
    """
    if param_dict is None:
        return {param.name: copy.deepcopy(param) for param in network.trainable_params()
                if 'learning_rate' not in param.name and 'adam' not in param.name}
    for param in network.trainable_params():
        if param.name in param_dict:
            param_dict[param.name] = copy.deepcopy(param)
    return None


def restore_params(network, param_dict, init_adam=True):
    """Writes param_dict back into the network's trainable parameters.

    Learning-rate parameters are skipped; with init_adam=True, Adam optimizer
    state is re-initialized to zeros instead of being restored.
    """
    for param in network.trainable_params():
        if 'learning_rate' in param.name:
            continue
        param.init_data()
        if init_adam:
            if 'adam' in param.name:
                param.set_data(initializer('zeros', shape=param.shape, dtype=param.dtype))
            elif param.name in param_dict:
                param.set_data(param_dict[param.name])
        else:
            if param.name in param_dict:
                param.set_data(param_dict[param.name])
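

# Hedged sketch (not part of the original change): a typical snapshot/restore
# cycle around a local training step; `network` is any mindspore nn.Cell and
# train_one_step is an illustrative callable.
def _demo_snapshot_restore(network, train_one_step=None):
    snapshot = save_params(network)  # deep-copied trainable parameters
    if train_one_step is not None:
        train_one_step()             # mutates network parameters in place
    restore_params(network, snapshot)  # roll the parameters back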
def get_worker_upload_list():
    """Names of the parameters each worker pushes to the server."""
    return [
        'albert.encoder.embedding_hidden_mapping_in.weight',
        'albert.encoder.embedding_hidden_mapping_in.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta',
        'albert.pooler.weight',
        'albert.pooler.bias',
        'classifier.weight',
        'classifier.bias']


def upload_to_server(network, worker_upload_list):
    """Marks the listed parameters to be pushed to the federated server."""
    for param in network.trainable_params():
        if param.name in worker_upload_list:
            param.set_param_fl(push_to_server=True)


def get_worker_download_list():
    """Names of the parameters each worker pulls from the server.

    Same as the upload list minus the pooler and classifier head, which stay
    worker-local.
    """
    return [
        'albert.encoder.embedding_hidden_mapping_in.weight',
        'albert.encoder.embedding_hidden_mapping_in.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma',
        'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta'
    ]


def download_from_server(network, worker_download_list):
    """Marks the listed parameters to be pulled from the federated server."""
    for param in network.trainable_params():
        if param.name in worker_download_list:
            param.set_param_fl(pull_from_server=True)


def get_freeze_list():
    """Names of the embedding parameters that are never trained."""
    return [
        'albert.word_embeddings.embedding_table',
        'albert.embedding_postprocessor.embedding_table',
        'albert.embedding_postprocessor.full_position_embeddings',
        'albert.embedding_postprocessor.layernorm.gamma',
        'albert.embedding_postprocessor.layernorm.beta'
    ]


def freeze(network, freeze_list):
    """Disables gradients for the listed parameters."""
    for param in network.trainable_params():
        if param.name in freeze_list:
            param.requires_grad = False
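

# Hedged sketch (not part of the original change): one plausible way to wire a
# worker-side network with the helpers above. Embedding layers stay local and
# frozen; the listed encoder/classifier parameters sync with the server.
def _demo_configure_worker(network):
    freeze(network, get_freeze_list())
    upload_to_server(network, get_worker_upload_list())
    download_from_server(network, get_worker_download_list())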