!19605 add federated learning albert

Merge pull request !19605 from wtcheng/master
This commit is contained in:
i-robot 2021-07-08 08:15:42 +00:00 committed by Gitee
commit 6863d52646
13 changed files with 3267 additions and 0 deletions

View File

@ -0,0 +1,101 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import sys
from time import time
from mindspore import context
from mindspore.train.serialization import load_checkpoint
from src.config import eval_cfg, server_net_cfg
from src.dataset import load_datasets
from src.utils import restore_params
from src.model import AlbertModelCLS
from src.tokenization import CustomizedTextTokenizer
from src.assessment_method import Accuracy
def parse_args():
"""
parse args
"""
parser = argparse.ArgumentParser(description='server eval task')
parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--device_id', type=str, default='0')
parser.add_argument('--tokenizer_dir', type=str, default='../model_save/init/')
parser.add_argument('--eval_data_dir', type=str, default='../datasets/eval/')
parser.add_argument('--model_path', type=str, default='../model_save/train_server/0.ckpt')
parser.add_argument('--vocab_map_ids_path', type=str, default='../model_save/init/vocab_map_ids.txt')
return parser.parse_args()
def server_eval(args):
start = time()
# some parameters
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
tokenizer_dir = args.tokenizer_dir
eval_data_dir = args.eval_data_dir
model_path = args.model_path
vocab_map_ids_path = args.vocab_map_ids_path
# mindspore context
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
print('Context setting is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# data process
tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
datasets_list, _ = load_datasets(
eval_data_dir, server_net_cfg.seq_length, tokenizer, eval_cfg.batch_size,
label_list=None,
do_shuffle=False,
drop_remainder=False,
output_dir=None)
print('Data process is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# main model
albert_model_cls = AlbertModelCLS(server_net_cfg)
albert_model_cls.set_train(False)
param_dict = load_checkpoint(model_path)
restore_params(albert_model_cls, param_dict)
print('Model construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# eval
callback = Accuracy()
global_step = 0
for datasets in datasets_list:
for batch in datasets.create_tuple_iterator():
input_ids, attention_mask, token_type_ids, label_ids, _ = batch
logits = albert_model_cls(input_ids, attention_mask, token_type_ids)
callback.update(logits, label_ids)
print('eval step: {}, {}: {}'.format(global_step, callback.name, callback.get_metrics()))
sys.stdout.flush()
global_step += 1
metrics = callback.get_metrics()
print('Final {}: {}'.format(callback.name, metrics))
sys.stdout.flush()
print('Evaluating process is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
if __name__ == '__main__':
args_opt = parse_args()
server_eval(args_opt)

View File

@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import sys
from time import time
from mindspore import context
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.adam import AdamWeightDecayOp as AdamWeightDecay
from src.tokenization import CustomizedTextTokenizer
from src.config import train_cfg, server_net_cfg
from src.dataset import load_dataset
from src.utils import restore_params
from src.model import AlbertModelCLS
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
def parse_args():
"""
parse args
"""
parser = argparse.ArgumentParser(description='server task')
parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--device_id', type=str, default='0')
parser.add_argument('--tokenizer_dir', type=str, default='../model_save/init/')
parser.add_argument('--server_data_path', type=str, default='../datasets/semi_supervise/server/train.txt')
parser.add_argument('--model_path', type=str, default='../model_save/init/albert_init.ckpt')
parser.add_argument('--output_dir', type=str, default='../model_save/train_server/')
parser.add_argument('--vocab_map_ids_path', type=str, default='../model_save/init/vocab_map_ids.txt')
parser.add_argument('--logging_step', type=int, default=1)
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=0.05)
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
return parser.parse_args()
def server_train(args):
start = time()
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
tokenizer_dir = args.tokenizer_dir
server_data_path = args.server_data_path
model_path = args.model_path
output_dir = args.output_dir
vocab_map_ids_path = args.vocab_map_ids_path
logging_step = args.logging_step
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
scheduler_manage_port = args.scheduler_manage_port
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
# Replace some parameters with federated learning parameters.
train_cfg.max_global_epoch = fl_iteration_num
fl_ctx = {
"enable_fl": True,
"server_mode": server_mode,
"ms_role": ms_role,
"worker_num": worker_num,
"server_num": server_num,
"scheduler_ip": scheduler_ip,
"scheduler_port": scheduler_port,
"fl_server_port": fl_server_port,
"start_fl_job_threshold": start_fl_job_threshold,
"start_fl_job_time_window": start_fl_job_time_window,
"update_model_ratio": update_model_ratio,
"update_model_time_window": update_model_time_window,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"scheduler_manage_port": scheduler_manage_port,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
}
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# construct tokenizer
tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# mindspore context
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
context.set_fl_context(**fl_ctx)
print('Context setting is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# construct model
albert_model_cls = AlbertModelCLS(server_net_cfg)
network_with_cls_loss = NetworkWithCLSLoss(albert_model_cls)
network_with_cls_loss.set_train(True)
print('Model construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# train prepare
global_step = 0
param_dict = load_checkpoint(model_path)
if 'learning_rate' in param_dict:
del param_dict['learning_rate']
# server optimizer
server_params = [_ for _ in network_with_cls_loss.trainable_params()
if 'word_embeddings' not in _.name
and 'postprocessor' not in _.name]
server_decay_params = list(
filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
)
server_other_params = list(
filter(lambda x: not train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter(x), server_params)
)
server_group_params = [
{'params': server_decay_params, 'weight_decay': train_cfg.optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': server_other_params, 'weight_decay': 0.0},
{'order_params': server_params}
]
server_optimizer = AdamWeightDecay(server_group_params,
learning_rate=train_cfg.server_cfg.learning_rate,
eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)
restore_params(server_network_train_cell, param_dict)
print('Optimizer construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# server load data
server_train_dataset, _ = load_dataset(
server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
label_list=None,
do_shuffle=True,
drop_remainder=True,
output_dir=None,
cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
)
print('Server data loading is done! Time cost: {}'.format(time() - start))
start = time()
# train process
for global_epoch in range(train_cfg.max_global_epoch):
for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
model_start_time = time()
cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
time_cost = time() - model_start_time
if global_step % logging_step == 0:
print_text = 'server: '
print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
print(print_text)
sys.stdout.flush()
global_step += 1
del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
output_path = os.path.join(
output_dir,
str(global_epoch*train_cfg.server_cfg.max_local_epoch+server_local_epoch)+'.ckpt'
)
save_checkpoint(server_network_train_cell.network, output_path)
print('Training process is done! Time cost: {}'.format(time() - start))
if __name__ == '__main__':
args_opt = parse_args()
server_train(args_opt)

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Finish train_cloud.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)
args, _ = parser.parse_known_args()
scheduler_port = args.scheduler_port
cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"
subprocess.call(['bash', '-c', cmd])

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Run train_cloud.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
cmd_sched += "python ${self_path}/../cloud_train.py"
cmd_sched += " --device_target=" + device_target
cmd_sched += " --server_mode=" + server_mode
cmd_sched += " --ms_role=MS_SCHED"
cmd_sched += " --worker_num=" + str(worker_num)
cmd_sched += " --server_num=" + str(server_num)
cmd_sched += " --scheduler_ip=" + scheduler_ip
cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])

View File

@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Run train_cloud.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
# The number of servers that this script will launch.
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=0.05)
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
if local_server_num == -1:
local_server_num = server_num
assert local_server_num <= server_num, "The local server number should not be bigger than total server number."
for i in range(local_server_num):
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
cmd_server += "python ${self_path}/../cloud_train.py"
cmd_server += " --device_target=" + device_target
cmd_server += " --server_mode=" + server_mode
cmd_server += " --ms_role=MS_SERVER"
cmd_server += " --worker_num=" + str(worker_num)
cmd_server += " --server_num=" + str(server_num)
cmd_server += " --scheduler_ip=" + scheduler_ip
cmd_server += " --scheduler_port=" + str(scheduler_port)
cmd_server += " --fl_server_port=" + str(fl_server_port + i)
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " > server.log 2>&1 &"
import time
time.sleep(0.3)
subprocess.call(['bash', '-c', cmd_server])

View File

@ -0,0 +1,424 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
import numpy as np
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
"""
Update parameters by AdamWeightDecay op.
"""
if optim_filter:
adam = P.AdamWeightDecay()
if decay_flags:
next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
else:
next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
return next_param
return gradient
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
"""
Update parameters.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
overflow (Tensor): Whether overflow occurs.
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
if optim_filter:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_select = P.Select()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))
next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay, param_fp32) + update
update_with_lr = op_mul(lr, update)
zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return op_cast(next_param, F.dtype(param))
return gradient
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter and not cache_enable:
op_shape = P.Shape()
shapes = (op_shape(param), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices), shapes), param))
return success
if not target:
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices))
else:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
scatter_add = P.ScatterAdd(use_locking)
assign_m = F.assign(m, op_mul(beta1, m))
assign_v = F.assign(v, op_mul(beta2, v))
grad_indices = gradient.indices
grad_value = gradient.values
next_m = scatter_add(m,
grad_indices,
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
next_v = scatter_add(v,
grad_indices,
op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))
if use_nesterov:
m_temp = next_m * _scaler_ten
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
div_value = scatter_add(m,
op_mul(grad_indices, _scaler_one),
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
param_update = div_value / (op_sqrt(next_v) + eps)
m_recover = F.assign(m, m_temp / _scaler_ten)
F.control_depend(m_temp, assign_m_nesterov)
F.control_depend(assign_m_nesterov, div_value)
F.control_depend(param_update, m_recover)
else:
param_update = next_m / (op_sqrt(next_v) + eps)
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
next_param = param - lr_t * param_update
F.control_depend(assign_m, next_m)
F.control_depend(assign_v, next_v)
success = F.depend(success, F.assign(param, next_param))
success = F.depend(success, F.assign(m, next_m))
success = F.depend(success, F.assign(v, next_v))
return success
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
moment1, moment2, ps_parameter, cache_enable):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter and not cache_enable:
op_shape = P.Shape()
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
else:
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
return success
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
success = True
delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
success = F.depend(success, F.assign_add(param, delat_param))
return success
def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)
class AdamWeightDecayForBert(Optimizer):
"""
Implements the Adam algorithm to fix the weight decay.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
which in the 'order_params' must be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
- **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale.
Outputs:
tuple[bool], all elements are True.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = AdamWeightDecay(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
self.op_select = P.Select()
self.op_cast = P.Cast()
self.op_reshape = P.Reshape()
self.op_shape = P.Shape()
def construct(self, gradients, overflow):
"""AdamWeightDecayForBert"""
lr = self.get_lr()
cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
self.op_reshape(overflow, (())), mstype.bool_)
beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
class AdamWeightDecayOp(Optimizer):
"""
Implements the Adam algorithm to fix the weight decay. It is a complete operator, not a combination of other ops.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
which in the 'order_params' must be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[bool], all elements are True.
Supported Platforms:
``GPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = AdamWeightDecayOp(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
def construct(self, gradients):
"""AdamWeightDecayOp"""
lr = self.get_lr()
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result

View File

@ -0,0 +1,138 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""assessment methods"""
import numpy as np
class Accuracy:
"""Accuracy"""
def __init__(self):
self.acc_num = 0
self.total_num = 0
self.name = 'Accuracy'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logit_id = np.argmax(logits, axis=-1)
self.acc_num += np.sum(labels == logit_id)
self.total_num += len(labels)
def get_metrics(self):
return self.acc_num / self.total_num * 100.0
class TopK:
"""F1"""
def __init__(self, k=5):
self.acc_num = 0
self.total_num = 0
self.k = k
self.name = 'Top' + str(k)
def update(self, logits, labels):
labels = labels.asnumpy()
logits = logits.asnumpy()
sorted_index = logits.argsort()
for i, label in enumerate(labels):
for j in range(self.k):
if sorted_index[i, -j-1] == label:
self.acc_num += 1
break
self.total_num += len(labels)
def get_metrics(self):
return self.acc_num / self.total_num * 100.0
class F1:
"""F1"""
def __init__(self):
self.logits_array = np.array([])
self.labels_array = np.array([])
self.name = 'F1'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.argmax(logits, axis=1)
self.labels_array = np.concatenate([self.labels_array, labels]).astype(np.bool)
self.logits_array = np.concatenate([self.logits_array, logits]).astype(np.bool)
def get_metrics(self):
if len(self.labels_array) < 2:
return 0.0
tp = np.sum(self.labels_array & self.logits_array)
fp = np.sum(self.labels_array & (~self.logits_array))
fn = np.sum((~self.labels_array) & self.logits_array)
p = tp / (tp + fp)
r = tp / (tp + fn)
return 2.0 * p * r / (p + r) * 100.0
class Pearsonr:
"""Pearsonr"""
def __init__(self):
self.logits_array = np.array([])
self.labels_array = np.array([])
self.name = 'Pearsonr'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.reshape(logits, -1)
self.labels_array = np.concatenate([self.labels_array, labels])
self.logits_array = np.concatenate([self.logits_array, logits])
def get_metrics(self):
if len(self.labels_array) < 2:
return 0.0
x_mean = self.logits_array.mean()
y_mean = self.labels_array.mean()
xm = self.logits_array - x_mean
ym = self.labels_array - y_mean
norm_xm = np.linalg.norm(xm)
norm_ym = np.linalg.norm(ym)
return np.dot(xm / norm_xm, ym / norm_ym) * 100.0
class Matthews:
"""Matthews"""
def __init__(self):
self.logits_array = np.array([])
self.labels_array = np.array([])
self.name = 'Matthews'
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.argmax(logits, axis=1)
self.labels_array = np.concatenate([self.labels_array, labels]).astype(np.bool)
self.logits_array = np.concatenate([self.logits_array, logits]).astype(np.bool)
def get_metrics(self):
if len(self.labels_array) < 2:
return 0.0
tp = np.sum(self.labels_array & self.logits_array)
fp = np.sum(self.labels_array & (~self.logits_array))
fn = np.sum((~self.labels_array) & self.logits_array)
tn = np.sum((~self.labels_array) & (~self.logits_array))
return (tp * tn - fp * fn) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) * 100.0

View File

@ -0,0 +1,299 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple
class ClipByNorm(nn.Cell):
"""
Clips tensor values to a maximum :math:`L_2`-norm.
"""
def __init__(self):
super(ClipByNorm, self).__init__()
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.select_ = P.Select()
self.greater_ = P.Greater()
self.cast = P.Cast()
self.sqrt = P.Sqrt()
self.max_op = P.Maximum()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.fill = P.Fill()
self.expand_dims = P.ExpandDims()
self.dtype = P.DType()
def construct(self, x, clip_norm):
"""add ms_function decorator for pynative mode"""
mul_x = F.square(x)
if mul_x.shape == (1,):
l2sum = self.cast(mul_x, mstype.float32)
else:
l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
cond = self.greater_(l2sum, 0)
ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
intermediate = x * clip_norm
max_norm = self.max_op(l2norm, clip_norm)
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
values_clip = self.reshape(values_clip, self.shape(x))
values_clip = F.identity(values_clip)
return values_clip
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class ClipGradients(nn.Cell):
"""
Clip gradients.
Inputs:
grads (list): List of gradient tuples.
clip_type (Tensor): The way to clip, 'value' or 'norm'.
clip_value (Tensor): Specifies how much to clip.
Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self,
grads,
clip_type,
clip_value):
"""clip gradients"""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
new_grads = new_grads + (t,)
return new_grads
class CrossEntropy(nn.Cell):
"""
Cross Entropy loss
"""
def __init__(self, num_labels):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
self.num_labels = num_labels
def construct(self, logits, label_ids):
label_ids = self.reshape(label_ids, self.last_idx)
one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
loss = self.reduce_mean(per_example_loss, self.last_idx)
return_value = self.cast(loss, mstype.float32)
return return_value
class NetworkWithCLSLoss(nn.Cell):
def __init__(self, network):
super(NetworkWithCLSLoss, self).__init__(auto_prefix=False)
self.cls_network = network
self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def construct(self, input_ids, input_mask, token_type_id, label_ids):
logits = self.cls_network(input_ids, input_mask, token_type_id)
cls_loss = self.loss_fct(logits, label_ids)
return cls_loss
class NetworkWithMLMLoss(nn.Cell):
def __init__(self, network, vocab_size=21128):
super(NetworkWithMLMLoss, self).__init__(auto_prefix=False)
self.mlm_network = network
self.vocab_size = vocab_size
self.reshape = P.Reshape()
self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def construct(self, input_ids, input_mask, token_type_id, label_ids):
prediction_scores = self.mlm_network(input_ids, input_mask, token_type_id)
prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))
label_ids = self.reshape(label_ids, (-1,))
mlm_loss = self.loss_fct(prediction_scores, label_ids)
return mlm_loss
class NetworkTrainCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
super(NetworkTrainCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.sens = sens
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.clip_type = 1
self.clip_value = 1.0
self.cast = P.Cast()
self.hyper_map = C.HyperMap()
self.get_weights_by_key = P.PullWeight()
self.over_weights_by_key = P.PushWeight()
self.get_weights_by_key_input_1, \
self.get_weights_by_key_input_2, \
self.get_weights_by_key_input_3 = self._get_weights_by_key_inputs(self.network.parameters_and_names())
self.over_weights_by_key_input_1, \
self.over_weights_by_key_input_2, \
self.over_weights_by_key_input_3 = self._over_weights_by_key_inputs(self.network.parameters_and_names())
def _communication_with_server_1(self, weights):
result = self.hyper_map(F.partial(self.get_weights_by_key), weights,
self.get_weights_by_key_input_2, self.get_weights_by_key_input_3)
return result
def _communication_with_server_2(self, weights):
result = self.hyper_map(F.partial(self.over_weights_by_key), weights,
self.over_weights_by_key_input_2,
self.over_weights_by_key_input_3)
return result
def _get_weights_by_key_inputs(self, weights):
filtered_weights = []
weight_names = []
weight_indices = []
index = 0
for weight in weights:
if weight[1].pull_weight_from_server:
filtered_weights.append(weight[1])
weight_names.append(weight[1].name)
weight_indices.append(index)
index += 1
return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)
def _over_weights_by_key_inputs(self, weights):
filtered_weights = []
weight_names = []
weight_indices = []
index = 0
for weight in weights:
if weight[1].push_weight_to_server:
filtered_weights.append(weight[1])
weight_names.append(weight[1].name)
weight_indices.append(index)
index += 1
return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)
def construct(self, input_ids, input_mask, token_type_id, label_ids):
weights = self.weights
res = self._communication_with_server_1(self.get_weights_by_key_input_1)
input_ids = F.depend(input_ids, res)
loss = self.network(input_ids, input_mask, token_type_id, label_ids)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
label_ids,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
loss = F.depend(loss, self.optimizer(grads))
weights1 = F.depend(self.over_weights_by_key_input_1, loss)
loss = F.depend(loss, self._communication_with_server_2(weights1))
return loss
class NetworkNoClientTrainCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
super(NetworkNoClientTrainCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.sens = sens
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.clip_type = 1
self.clip_value = 1.0
self.cast = P.Cast()
self.hyper_map = C.HyperMap()
def construct(self, input_ids, input_mask, token_type_id, label_ids):
weights = self.weights
loss = self.network(input_ids, input_mask, token_type_id, label_ids)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
label_ids,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
succ = self.optimizer(grads)
return F.depend(loss, succ)

View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config script"""
from easydict import EasyDict as edict
from mindspore.common import dtype as mstype
from src.model import AlbertConfig
gradient_cfg = edict({
'clip_type': 1,
'clip_value': 1.0
})
train_cfg = edict({
'batch_size': 16,
'loss_scale_value': 2 ** 16,
'scale_factor': 2,
'scale_window': 50,
'max_global_epoch': 10, #fl_iteration_num
'server_cfg': edict({
'learning_rate': 1e-5,
'max_local_epoch': 1,
'cyclic_trunc': False
}),
'client_cfg': edict({
'learning_rate': 1e-5,
'max_local_epoch': 1,
'num_per_epoch': 20,
'cyclic_trunc': True
}),
'optimizer_cfg': edict({
'AdamWeightDecay': edict({
'end_learning_rate': 1e-14,
'power': 1.0,
'weight_decay': 1e-4,
'eps': 1e-6,
'decay_filter': lambda x: 'norm' not in x.name.lower() and 'bias' not in x.name.lower(),
'warmup_ratio': 0.1
}),
}),
})
eval_cfg = edict({
'batch_size': 256,
})
server_net_cfg = AlbertConfig(
seq_length=8,
vocab_size=11682,
hidden_size=312,
num_hidden_groups=1,
num_hidden_layers=4,
inner_group_num=1,
num_attention_heads=12,
intermediate_size=1248,
hidden_act="gelu",
query_act=None,
key_act=None,
value_act=None,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
classifier_dropout_prob=0.0,
embedding_size=128,
layer_norm_eps=1e-12,
has_attention_mask=True,
do_return_2d_tensor=True,
use_one_hot_embeddings=False,
use_token_type=True,
return_all_encoders=False,
output_attentions=False,
output_hidden_states=False,
dtype=mstype.float32,
compute_type=mstype.float32,
is_training=True,
num_labels=4,
use_word_embeddings=True
)
client_net_cfg = AlbertConfig(
seq_length=8,
vocab_size=11682,
hidden_size=312,
num_hidden_groups=1,
num_hidden_layers=4,
inner_group_num=1,
num_attention_heads=12,
intermediate_size=1248,
hidden_act="gelu",
query_act=None,
key_act=None,
value_act=None,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
classifier_dropout_prob=0.0,
embedding_size=128,
layer_norm_eps=1e-12,
has_attention_mask=True,
do_return_2d_tensor=True,
use_one_hot_embeddings=False,
use_token_type=True,
return_all_encoders=False,
output_attentions=False,
output_hidden_states=False,
dtype=mstype.float32,
compute_type=mstype.float32,
is_training=True,
num_labels=4,
use_word_embeddings=True
)

View File

@ -0,0 +1,187 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pickle
import numpy as np
from mindspore import dataset as ds
from mindspore.dataset.transforms import c_transforms as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
class InputFeatures:
"""A single set of features of data."""
def __init__(self, input_ids, input_mask, segment_ids, label_id, seq_length=None):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.seq_length = seq_length
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, cyclic_trunc=False):
"""Loads a data file into a list of `InputBatch`s."""
label_map = {label: _ for _, label in enumerate(label_list)}
features = []
for example in examples:
tokens = tokenizer.tokenize(example[0])
seq_length = len(tokens)
if seq_length > max_seq_length - 2:
if cyclic_trunc:
rand_index = np.random.randint(0, seq_length)
tokens = [tokens[_] if _ < seq_length else tokens[_ - seq_length]
for _ in range(rand_index, rand_index + max_seq_length - 2)]
else:
tokens = tokens[: (max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
segment_ids = [0] * len(tokens)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
seq_length = len(input_ids)
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
label_id = label_map[example[1]]
features.append(InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id,
seq_length=seq_length))
return features
def load_dataset(data_path, max_seq_length, tokenizer, batch_size, label_list=None, do_shuffle=True,
drop_remainder=True, output_dir=None, i=0, cyclic_trunc=False):
if label_list is None:
label_list = ['good', 'leimu', 'xiaoku', 'xin']
with open(data_path, 'r', encoding='utf-8') as f:
data = f.read()
data_list = data.split('\n<<<')
input_list = []
for key in data_list[1:]:
key = key.split('>>>')
input_list.append([key[1], key[0]])
datasets = create_ms_dataset(input_list, label_list, max_seq_length, tokenizer, batch_size,
do_shuffle=do_shuffle, drop_remainder=drop_remainder, cyclic_trunc=cyclic_trunc)
if output_dir is not None:
output_path = os.path.join(output_dir, str(i) + '.dat')
print(output_path)
with open(output_path, "wb") as f:
pickle.dump(tuple(datasets), f)
del data, data_list, input_list
return datasets, len(label_list)
def load_datasets(data_dir, max_seq_length, tokenizer, batch_size, label_list=None, do_shuffle=True,
drop_remainder=True, output_dir=None, cyclic_trunc=False):
if label_list is None:
label_list = ['good', 'leimu', 'xiaoku', 'xin']
data_path_list = os.listdir(data_dir)
datasets_list = []
for i, relative_path in enumerate(data_path_list):
data_path = os.path.join(data_dir, relative_path)
with open(data_path, 'r', encoding='utf-8') as f:
data = f.read()
data_list = data.split('\n<<<')
input_list = []
for key in data_list[1:]:
key = key.split('>>>')
input_list.append([key[1], key[0]])
datasets = create_ms_dataset(input_list, label_list, max_seq_length, tokenizer, batch_size,
do_shuffle=do_shuffle, drop_remainder=drop_remainder, cyclic_trunc=cyclic_trunc)
if output_dir is not None:
output_path = os.path.join(output_dir, str(i) + '.dat')
print(output_path)
with open(output_path, "wb") as f:
pickle.dump(tuple(datasets.create_tuple_iterator()), f)
datasets_list.append(datasets)
return datasets_list, len(label_list)
def create_ms_dataset(data_list, label_list, max_seq_length, tokenizer, batch_size, do_shuffle=True,
drop_remainder=True, cyclic_trunc=False):
features = convert_examples_to_features(data_list, label_list, max_seq_length, tokenizer,
cyclic_trunc=cyclic_trunc)
def generator_func():
for feature in features:
yield (np.array(feature.input_ids),
np.array(feature.input_mask),
np.array(feature.segment_ids),
np.array(feature.label_id),
np.array(feature.seq_length))
dataset = ds.GeneratorDataset(generator_func,
['input_ids', 'input_mask', 'token_type_id', 'label_ids', 'seq_length'])
if do_shuffle:
dataset = dataset.shuffle(buffer_size=10000)
type_cast_op = C.TypeCast(mstype.int32)
dataset = dataset.map(operations=[type_cast_op])
dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
return dataset
class ConstructMaskAndReplaceTensor:
def __init__(self, batch_size, max_seq_length, vocab_size, keep_first_unchange=True, keep_last_unchange=True):
self.batch_size = batch_size
self.max_seq_length = max_seq_length
self.vocab_size = vocab_size
self.keep_first_unchange = keep_first_unchange
self.keep_last_unchange = keep_last_unchange
self.mask_tensor = np.ones((self.batch_size, self.max_seq_length))
self.replace_tensor = np.zeros((self.batch_size, self.max_seq_length))
def construct(self, seq_lengths):
for i in range(self.batch_size):
for j in range(seq_lengths[i]):
rand1 = np.random.random()
if rand1 < 0.15:
self.mask_tensor[i, j] = 0
rand2 = np.random.random()
if rand2 < 0.8:
self.replace_tensor[i, j] = 103
elif rand2 < 0.9:
self.mask_tensor[i, j] = 1
else:
self.replace_tensor[i, j] = np.random.randint(0, self.vocab_size)
else:
self.mask_tensor[i, j] = 1
self.replace_tensor[i, j] = 0
for j in range(seq_lengths[i], self.max_seq_length):
self.mask_tensor[i, j] = 1
self.replace_tensor[i, j] = 0
if self.keep_first_unchange:
self.mask_tensor[i, 0] = 1
self.replace_tensor[i, 0] = 0
if self.keep_last_unchange:
self.mask_tensor[i, seq_lengths[i] - 1] = 1
self.replace_tensor[i, seq_lengths[i] - 1] = 0
mask_tensor = Tensor(self.mask_tensor, dtype=mstype.int32)
replace_tensor = Tensor(self.replace_tensor, dtype=mstype.int32)
return mask_tensor, replace_tensor

View File

@ -0,0 +1,892 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import copy
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
class AlbertConfig:
"""
Configuration for `AlbertModel`.
Args:
seq_length (int): Length of input sequence. Default: 128.
vocab_size (int): The shape of each embedding vector. Default: 32000.
hidden_size (int): Size of the bert encoder layers. Default: 768.
num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
cell. Default: 12.
num_attention_heads (int): Number of attention heads in the BertTransformer
encoder cell. Default: 12.
intermediate_size (int): Size of intermediate layer in the BertTransformer
encoder cell. Default: 3072.
hidden_act (str): Activation function used in the BertTransformer encoder
cell. Default: "gelu".
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
type_vocab_size (int): Size of token type vocab. Default: 16.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
seq_length=256,
vocab_size=21128,
hidden_size=312,
num_hidden_groups=1,
num_hidden_layers=4,
inner_group_num=1,
num_attention_heads=12,
intermediate_size=1248,
hidden_act="gelu",
query_act=None,
key_act=None,
value_act=None,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
classifier_dropout_prob=0.1,
embedding_size=128,
layer_norm_eps=1e-12,
has_attention_mask=True,
do_return_2d_tensor=True,
use_one_hot_embeddings=False,
use_token_type=True,
return_all_encoders=False,
output_attentions=False,
output_hidden_states=False,
dtype=mstype.float32,
compute_type=mstype.float32,
is_training=True,
num_labels=5,
use_word_embeddings=True):
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.inner_group_num = inner_group_num
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.query_act = query_act
self.key_act = key_act
self.value_act = value_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.use_relative_positions = use_relative_positions
self.classifier_dropout_prob = classifier_dropout_prob
self.embedding_size = embedding_size
self.layer_norm_eps = layer_norm_eps
self.num_hidden_groups = num_hidden_groups
self.has_attention_mask = has_attention_mask
self.do_return_2d_tensor = do_return_2d_tensor
self.use_one_hot_embeddings = use_one_hot_embeddings
self.use_token_type = use_token_type
self.return_all_encoders = return_all_encoders
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.dtype = dtype
self.compute_type = compute_type
self.is_training = is_training
self.num_labels = num_labels
self.use_word_embeddings = use_word_embeddings
class EmbeddingLookup(nn.Cell):
"""
A embeddings lookup table with a fixed dictionary and size.
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(EmbeddingLookup, self).__init__()
self.vocab_size = config.vocab_size
self.use_one_hot_embeddings = config.use_one_hot_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(config.initializer_range),
[config.vocab_size, config.embedding_size]),
name='embedding_table')
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.Gather()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = (-1, config.seq_length, config.embedding_size)
def construct(self, input_ids):
"""embedding lookup"""
flat_ids = self.reshape(input_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(
one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, self.shape)
return output, self.embedding_table
class EmbeddingPostprocessor(nn.Cell):
"""
Postprocessors apply positional and token type embeddings to word embeddings.
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(EmbeddingPostprocessor, self).__init__()
self.use_token_type = config.use_token_type
self.token_type_vocab_size = config.type_vocab_size
self.use_one_hot_embeddings = config.use_one_hot_embeddings
self.max_position_embeddings = config.max_position_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(config.initializer_range),
[config.type_vocab_size,
config.embedding_size]))
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.1, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = (-1, config.seq_length, config.embedding_size)
self.layernorm = nn.LayerNorm((config.embedding_size,))
self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
self.gather = P.Gather()
self.use_relative_positions = config.use_relative_positions
self.slice = P.StridedSlice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(config.initializer_range),
[config.max_position_embeddings,
config.embedding_size]))
def construct(self, token_type_ids, word_embeddings):
"""embedding postprocessor"""
output = word_embeddings
if self.use_token_type:
flat_ids = self.reshape(token_type_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids,
self.token_type_vocab_size, self.on_value, self.off_value)
token_type_embeddings = self.array_mul(one_hot_ids,
self.embedding_table)
else:
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
output += token_type_embeddings
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
output = self.layernorm(output)
output = self.dropout(output)
return output
class AlbertOutput(nn.Cell):
"""
Apply a linear computation to hidden status and a residual computation to input.
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertOutput, self).__init__()
self.dense = nn.Dense(config.hidden_size, config.hidden_size,
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
self.add = P.Add()
self.is_gpu = context.get_context('device_target') == "GPU"
if self.is_gpu:
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(mstype.float32)
self.compute_type = config.compute_type
else:
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
"""bert output"""
output = self.dense(hidden_status)
output = self.dropout(output)
output = self.add(input_tensor, output)
output = self.layernorm(output)
if self.is_gpu:
output = self.cast(output, self.compute_type)
return output
class RelaPosMatrixGenerator(nn.Cell):
"""
Generates matrix of relative positions between inputs.
Args:
length (int): Length of one dim for the matrix to be generated.
max_relative_position (int): Max value of relative position.
"""
def __init__(self, length, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._length = length
self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
self.range_length = -length + 1
self.tile = P.Tile()
self.range_mat = P.Reshape()
self.sub = P.Sub()
self.expanddims = P.ExpandDims()
self.cast = P.Cast()
def construct(self):
"""position matrix generator"""
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
tile_row_out = self.tile(range_vec_row_out, (self._length,))
tile_col_out = self.tile(range_vec_col_out, (1, self._length))
range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
distance_mat = self.sub(range_mat_out, transpose_out)
distance_mat_clipped = C.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
# Shift values to be >=0. Each integer still uniquely identifies a
# relative position difference.
final_mat = distance_mat_clipped + self._max_relative_position
return final_mat
class RelaPosEmbeddingsGenerator(nn.Cell):
"""
Generates tensor of size [length, length, depth].
Args:
length (int): Length of one dim for the matrix to be generated.
depth (int): Size of each attention head.
max_relative_position (int): Maxmum value of relative position.
initializer_range (float): Initialization value of TruncatedNormal.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
length,
depth,
max_relative_position,
initializer_range,
use_one_hot_embeddings=False):
super(RelaPosEmbeddingsGenerator, self).__init__()
self.depth = depth
self.vocab_size = max_relative_position * 2 + 1
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embeddings_table = Parameter(
initializer(TruncatedNormal(initializer_range),
[self.vocab_size, self.depth]),
name='embeddings_for_position')
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
max_relative_position=max_relative_position)
self.reshape = P.Reshape()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.shape = P.Shape()
self.gather = P.Gather() # index_select
self.matmul = P.BatchMatMul()
def construct(self):
"""position embedding generation"""
relative_positions_matrix_out = self.relative_positions_matrix()
# Generate embedding for each relative position of dimension depth.
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
else:
embeddings = self.gather(self.embeddings_table,
relative_positions_matrix_out, 0)
return embeddings
class SaturateCast(nn.Cell):
"""
Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
the danger that the value will overflow or underflow.
Args:
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
"""
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
min_type = np.finfo(np_type).min
max_type = np.finfo(np_type).max
self.tensor_min_type = Tensor([min_type], dtype=src_type)
self.tensor_max_type = Tensor([max_type], dtype=src_type)
self.min_op = P.Minimum()
self.max_op = P.Maximum()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
"""saturate cast"""
out = self.max_op(x, self.tensor_min_type)
out = self.min_op(out, self.tensor_max_type)
return self.cast(out, self.dst_type)
class AlbertAttention(nn.Cell):
"""
Apply multi-headed attention from "from_tensor" to "to_tensor".
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertAttention, self).__init__()
self.from_seq_length = config.seq_length
self.to_seq_length = config.seq_length
self.num_attention_heads = config.num_attention_heads
self.size_per_head = int(config.hidden_size / config.num_attention_heads)
self.has_attention_mask = config.has_attention_mask
self.use_relative_positions = config.use_relative_positions
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=config.compute_type)
self.reshape = P.Reshape()
self.shape_from_2d = (-1, config.hidden_size)
self.shape_to_2d = (-1, config.hidden_size)
weight = TruncatedNormal(config.initializer_range)
self.query = nn.Dense(config.hidden_size,
config.hidden_size,
activation=config.query_act,
weight_init=weight).to_float(config.compute_type)
self.key = nn.Dense(config.hidden_size,
config.hidden_size,
activation=config.key_act,
weight_init=weight).to_float(config.compute_type)
self.value = nn.Dense(config.hidden_size,
config.hidden_size,
activation=config.value_act,
weight_init=weight).to_float(config.compute_type)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.matmul = P.BatchMatMul()
self.shape_from = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
self.shape_to = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
self.multiply = P.Mul()
self.transpose = P.Transpose()
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0], dtype=config.compute_type)
self.softmax = nn.Softmax()
self.dropout = nn.Dropout(1 - config.attention_probs_dropout_prob)
if self.has_attention_mask:
self.expand_dims = P.ExpandDims()
self.sub = P.Sub()
self.add = P.Add()
self.cast = P.Cast()
self.get_dtype = P.DType()
if config.do_return_2d_tensor:
self.shape_return = (-1, config.hidden_size)
else:
self.shape_return = (-1, config.seq_length, config.hidden_size)
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
if self.use_relative_positions:
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=config.seq_length,
depth=self.size_per_head,
max_relative_position=16,
initializer_range=config.initializer_range,
use_one_hot_embeddings=config.use_one_hot_embeddings)
def construct(self, from_tensor, to_tensor, attention_mask):
"""bert attention"""
# reshape 2d/3d input tensors to 2d
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
query_out = self.query(from_tensor_2d)
key_out = self.key(to_tensor_2d)
value_out = self.value(to_tensor_2d)
query_layer = self.reshape(query_out, self.shape_from)
query_layer = self.transpose(query_layer, self.trans_shape)
key_layer = self.reshape(key_out, self.shape_to)
key_layer = self.transpose(key_layer, self.trans_shape)
attention_scores = self.matmul_trans_b(query_layer, key_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
# relations_keys is [F|T, F|T, H]
relations_keys = self._generate_relative_positions_embeddings()
relations_keys = self.cast_compute_type(relations_keys)
# query_layer_t is [F, B, N, H]
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
# query_layer_r is [F, B * N, H]
query_layer_r = self.reshape(query_layer_t,
(self.from_seq_length,
-1,
self.size_per_head))
# key_position_scores is [F, B * N, F|T]
key_position_scores = self.matmul_trans_b(query_layer_r,
relations_keys)
# key_position_scores_r is [F, B, N, F|T]
key_position_scores_r = self.reshape(key_position_scores,
(self.from_seq_length,
-1,
self.num_attention_heads,
self.from_seq_length))
# key_position_scores_r_t is [B, N, F, F|T]
key_position_scores_r_t = self.transpose(key_position_scores_r,
self.trans_shape_position)
attention_scores = attention_scores + key_position_scores_r_t
attention_scores = self.multiply(self.scores_mul, attention_scores)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
self.cast(attention_mask, self.get_dtype(attention_scores)))
adder = self.multiply(multiply_out, self.multiply_data)
attention_scores = self.add(adder, attention_scores)
attention_probs = self.softmax(attention_scores)
attention_probs = self.dropout(attention_probs)
value_layer = self.reshape(value_out, self.shape_to)
value_layer = self.transpose(value_layer, self.trans_shape)
context_layer = self.matmul(attention_probs, value_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
# relations_values is [F|T, F|T, H]
relations_values = self._generate_relative_positions_embeddings()
relations_values = self.cast_compute_type(relations_values)
# attention_probs_t is [F, B, N, T]
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
# attention_probs_r is [F, B * N, T]
attention_probs_r = self.reshape(
attention_probs_t,
(self.from_seq_length,
-1,
self.to_seq_length))
# value_position_scores is [F, B * N, H]
value_position_scores = self.matmul(attention_probs_r,
relations_values)
# value_position_scores_r is [F, B, N, H]
value_position_scores_r = self.reshape(value_position_scores,
(self.from_seq_length,
-1,
self.num_attention_heads,
self.size_per_head))
# value_position_scores_r_t is [B, N, F, H]
value_position_scores_r_t = self.transpose(value_position_scores_r,
self.trans_shape_position)
context_layer = context_layer + value_position_scores_r_t
context_layer = self.transpose(context_layer, self.trans_shape)
context_layer = self.reshape(context_layer, self.shape_return)
return context_layer, attention_scores
class AlbertSelfAttention(nn.Cell):
"""
Apply self-attention.
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.attention = AlbertAttention(config)
self.output = AlbertOutput(config)
self.reshape = P.Reshape()
self.shape = (-1, config.hidden_size)
def construct(self, input_tensor, attention_mask):
"""bert self attention"""
input_tensor = self.reshape(input_tensor, self.shape)
attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask)
output = self.output(attention_output, input_tensor)
return output, attention_scores
class AlbertEncoderCell(nn.Cell):
"""
Encoder cells used in BertTransformer.
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertEncoderCell, self).__init__()
self.attention = AlbertSelfAttention(config)
self.intermediate = nn.Dense(in_channels=config.hidden_size,
out_channels=config.intermediate_size,
activation=config.hidden_act,
weight_init=TruncatedNormal(config.initializer_range)
).to_float(config.compute_type)
self.output = AlbertOutput(config)
def construct(self, hidden_states, attention_mask):
"""bert encoder cell"""
# self-attention
attention_output, attention_scores = self.attention(hidden_states, attention_mask)
# feed construct
intermediate_output = self.intermediate(attention_output)
# add and normalize
output = self.output(intermediate_output, attention_output)
return output, attention_scores
class AlbertLayer(nn.Cell):
"""
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertLayer, self).__init__()
self.output_attentions = config.output_attentions
self.attention = AlbertSelfAttention(config)
self.ffn = nn.Dense(config.hidden_size,
config.intermediate_size,
activation=config.hidden_act).to_float(config.compute_type)
self.ffn_output = nn.Dense(config.intermediate_size, config.hidden_size)
self.full_layer_layer_norm = nn.LayerNorm((config.hidden_size,))
self.shape = (-1, config.seq_length, config.hidden_size)
self.reshape = P.Reshape()
def construct(self, hidden_states, attention_mask):
attention_output, attention_scores = self.attention(hidden_states, attention_mask)
ffn_output = self.ffn(attention_output)
ffn_output = self.ffn_output(ffn_output)
ffn_output = self.reshape(ffn_output + attention_output, self.shape)
hidden_states = self.full_layer_layer_norm(ffn_output)
return hidden_states, attention_scores
class AlbertLayerGroup(nn.Cell):
"""
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertLayerGroup, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.albert_layers = nn.CellList([AlbertLayer(config) for _ in range(config.inner_group_num)])
def construct(self, hidden_states, attention_mask):
layer_hidden_states = ()
layer_attentions = ()
for _, albert_layer in enumerate(self.albert_layers):
layer_output = albert_layer(hidden_states, attention_mask)
hidden_states = layer_output[0]
if self.output_attentions:
layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_attentions:
outputs = outputs + (layer_attentions,)
if self.output_hidden_states:
outputs = outputs + (layer_hidden_states,)
return outputs
class AlbertTransformer(nn.Cell):
"""
Multi-layer bert transformer.
Args:
config (AlbertConfig): Albert Config.
"""
def __init__(self, config):
super(AlbertTransformer, self).__init__()
self.num_hidden_layers = config.num_hidden_layers
self.num_hidden_groups = config.num_hidden_groups
self.group_idx_list = [int(_ / (config.num_hidden_layers / config.num_hidden_groups))
for _ in range(config.num_hidden_layers)]
self.embedding_hidden_mapping_in = nn.Dense(config.embedding_size, config.hidden_size)
self.return_all_encoders = config.return_all_encoders
layers = []
for _ in range(config.num_hidden_groups):
layer = AlbertLayerGroup(config)
layers.append(layer)
self.albert_layer_groups = nn.CellList(layers)
self.reshape = P.Reshape()
self.shape = (-1, config.embedding_size)
self.out_shape = (-1, config.seq_length, config.hidden_size)
def construct(self, input_tensor, attention_mask):
"""bert transformer"""
prev_output = self.reshape(input_tensor, self.shape)
prev_output = self.embedding_hidden_mapping_in(prev_output)
all_encoder_layers = ()
all_encoder_atts = ()
all_encoder_outputs = (prev_output,)
# for layer_module in self.layers:
for i in range(self.num_hidden_layers):
# Index of the hidden group
group_idx = self.group_idx_list[i]
layer_output, encoder_att = self.albert_layer_groups[group_idx](prev_output, attention_mask)
prev_output = layer_output
if self.return_all_encoders:
all_encoder_outputs += (layer_output,)
layer_output = self.reshape(layer_output, self.out_shape)
all_encoder_layers += (layer_output,)
all_encoder_atts += (encoder_att,)
if not self.return_all_encoders:
prev_output = self.reshape(prev_output, self.out_shape)
all_encoder_layers += (prev_output,)
return prev_output
class CreateAttentionMaskFromInputMask(nn.Cell):
"""
Create attention mask according to input mask.
Args:
config (Class): Configuration for BertModel.
"""
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (-1, 1, config.seq_length)
def construct(self, input_mask):
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
return attention_mask
class AlbertModel(nn.Cell):
"""
Bidirectional Encoder Representations from Transformers.
Args:
config (Class): Configuration for BertModel.
"""
def __init__(self, config):
super(AlbertModel, self).__init__()
config = copy.deepcopy(config)
if not config.is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.token_type_ids = None
self.last_idx = self.num_hidden_layers - 1
self.use_word_embeddings = config.use_word_embeddings
if self.use_word_embeddings:
self.word_embeddings = EmbeddingLookup(config)
self.embedding_postprocessor = EmbeddingPostprocessor(config)
self.encoder = AlbertTransformer(config)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
self.slice = P.StridedSlice()
self.squeeze_1 = P.Squeeze(axis=1)
self.pooler = nn.Dense(self.hidden_size, self.hidden_size,
activation="tanh",
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
def construct(self, input_ids, token_type_ids, input_mask):
"""bert model"""
# embedding
if self.use_word_embeddings:
word_embeddings, _ = self.word_embeddings(input_ids)
else:
word_embeddings = input_ids
embedding_output = self.embedding_postprocessor(token_type_ids, word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
# bert encoder
encoder_output = self.encoder(self.cast_compute_type(embedding_output), attention_mask)
sequence_output = self.cast(encoder_output, self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.pooler(first_token)
pooled_output = self.cast(pooled_output, self.dtype)
return sequence_output, pooled_output
class AlbertMLMHead(nn.Cell):
"""
Get masked lm output.
Args:
config (AlbertConfig): The config of BertModel.
Returns:
Tensor, masked lm output.
"""
def __init__(self, config):
super(AlbertMLMHead, self).__init__()
self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_type)
self.dense = nn.Dense(
config.hidden_size,
config.embedding_size,
weight_init=TruncatedNormal(config.initializer_range),
activation=config.hidden_act
).to_float(config.compute_type)
self.decoder = nn.Dense(
config.embedding_size,
config.vocab_size,
weight_init=TruncatedNormal(config.initializer_range),
).to_float(config.compute_type)
def construct(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.layernorm(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class AlbertModelCLS(nn.Cell):
"""
This class is responsible for classification task evaluation,
i.e. mnli(num_labels=3), qnli(num_labels=2), qqp(num_labels=2).
The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
"""
def __init__(self, config):
super(AlbertModelCLS, self).__init__()
self.albert = AlbertModel(config)
self.cast = P.Cast()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.classifier = nn.Dense(config.hidden_size, config.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.relu = nn.ReLU()
self.is_training = config.is_training
if self.is_training:
self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)
def construct(self, input_ids, input_mask, token_type_id):
"""classification albert model"""
_, pooled_output = self.albert(input_ids, token_type_id, input_mask)
# pooled_output = self.relu(pooled_output)
if self.is_training:
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits = self.cast(logits, self.dtype)
return logits
class AlbertModelForAD(nn.Cell):
"""albert model for ad"""
def __init__(self, config):
super(AlbertModelForAD, self).__init__()
# main model
self.albert = AlbertModel(config)
# classifier head
self.cast = P.Cast()
self.dtype = config.dtype
self.classifier = nn.Dense(config.hidden_size, config.num_labels,
weight_init=TruncatedNormal(config.initializer_range),
has_bias=True).to_float(config.compute_type)
self.is_training = config.is_training
if self.is_training:
self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)
# masked language model head
self.predictions = AlbertMLMHead(config)
def construct(self, input_ids, input_mask, token_type_id):
"""albert model for ad"""
sequence_output, pooled_output = self.albert(input_ids, token_type_id, input_mask)
if self.is_training:
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits = self.cast(logits, self.dtype)
prediction_scores = self.predictions(sequence_output)
prediction_scores = self.cast(prediction_scores, self.dtype)
return prediction_scores, logits
class AlbertModelMLM(nn.Cell):
"""albert model for mlm"""
def __init__(self, config):
super(AlbertModelMLM, self).__init__()
self.cast = P.Cast()
self.dtype = config.dtype
# main model
self.albert = AlbertModel(config)
# masked language model head
self.predictions = AlbertMLMHead(config)
def construct(self, input_ids, input_mask, token_type_id):
"""albert model for mlm"""
sequence_output, _ = self.albert(input_ids, token_type_id, input_mask)
prediction_scores = self.predictions(sequence_output)
prediction_scores = self.cast(prediction_scores, self.dtype)
return prediction_scores

View File

@ -0,0 +1,558 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization classes."""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import logging
import os
import unicodedata
from io import open
logger = logging.getLogger(__name__)
def load_vocab(vocab_file, vocab_map_ids_path=None):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
if vocab_map_ids_path is not None:
vocab_new_ids = list()
with open(vocab_map_ids_path, "r", encoding="utf-8") as vocab_new_ids_reader:
while True:
index = vocab_new_ids_reader.readline()
if not index:
break
index = index.strip()
vocab_new_ids.append(int(index))
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
vocab[token] = vocab_new_ids[index]
index += 1
return vocab
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class BertTokenizer:
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, basic_only=False,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BertTokenizer.
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
"""
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
self.basic_only = basic_only
def tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text):
if self.basic_only:
split_tokens.append(token)
else:
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(self.vocab.get(token, self.vocab['[UNK]']))
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, 'vocab.txt')
else:
raise FileNotFoundError
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if 'txt' in pretrained_model_name_or_path:
resolved_vocab_file = pretrained_model_name_or_path
else:
resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'vocab.txt')
max_len = 512
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
class BasicTokenizer:
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self,
do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
@staticmethod
def _run_strip_accents(text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
@staticmethod
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(0x4E00 <= cp <= 0x9FFF) or
(0x3400 <= cp <= 0x4DBF) or
(0x20000 <= cp <= 0x2A6DF) or
(0x2A700 <= cp <= 0x2B73F) or
(0x2B740 <= cp <= 0x2B81F) or
(0x2B820 <= cp <= 0x2CEAF) or
(0xF900 <= cp <= 0xFAFF) or
(0x2F800 <= cp <= 0x2FA1F)
):
return True
return False
@staticmethod
def _clean_text(text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer:
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char in (" ", "\t", "\n", "\r"):
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char in ("\t", "\n", "\r"):
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
class CustomizedBasicTokenizer(BasicTokenizer):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
super().__init__(do_lower_case, never_split)
self.do_lower_case = do_lower_case
self.never_split = never_split
self.keywords = keywords
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
if self.keywords is not None:
new_orig_tokens = []
lengths = [len(_) for _ in self.keywords]
max_length = max(lengths)
orig_tokens_len = len(orig_tokens)
i = 0
while i < orig_tokens_len:
has_add = False
for length in range(max_length, 0, -1):
if i + length > orig_tokens_len:
continue
add_token = ''.join(orig_tokens[i:i+length])
if add_token in self.keywords:
new_orig_tokens.append(add_token)
i += length
has_add = True
break
if not has_add:
new_orig_tokens.append(orig_tokens[i])
i += 1
else:
new_orig_tokens = orig_tokens
split_tokens = []
for token in new_orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
class CustomizedTokenizer(BertTokenizer):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
"""Constructs a CustomizedTokenizer.
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
"""
super().__init__(vocab_file, do_lower_case, max_len, never_split)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = CustomizedBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split,
keywords=keywords)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
split_tokens = []
basic_tokens = self.basic_tokenizer.tokenize(text)
for token in basic_tokens:
wordpiece_tokens = self.wordpiece_tokenizer.tokenize(token)
for sub_token in wordpiece_tokens:
split_tokens.append(sub_token)
return split_tokens
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'customized_vocab.txt')
max_len = 512
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
class CustomizedTextBasicTokenizer(BasicTokenizer):
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp) or len(char.encode('utf-8')) > 1:
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
class CustomizedTextTokenizer(BertTokenizer):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"),
vocab_map_ids_path=None):
"""Constructs a CustomizedTokenizer.
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
"""
super().__init__(vocab_file, do_lower_case, max_len, never_split)
self.vocab = load_vocab(vocab_file, vocab_map_ids_path)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = CustomizedTextBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)

View File

@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import copy
from mindspore.common.initializer import initializer
def average_weights(para_list):
global_parameter = {}
length = len(para_list)
for para in para_list:
for name in para:
if name in global_parameter:
global_parameter[name] += para[name] / length
else:
global_parameter[name] = para[name] / length
return global_parameter
def save_params(network, param_dict=None):
if param_dict is None:
return {param.name: copy.deepcopy(param) for param in network.trainable_params()
if 'learning_rate' not in param.name and 'adam' not in param.name}
for param in network.trainable_params():
if param.name in param_dict:
param_dict[param.name] = copy.deepcopy(param)
return None
def restore_params(network, param_dict, init_adam=True):
for param in network.trainable_params():
if 'learning_rate' in param.name:
continue
param.init_data()
if init_adam:
if 'adam' in param.name:
param.set_data(initializer('zeros', shape=param.shape, dtype=param.dtype))
elif param.name in param_dict:
param.set_data(param_dict[param.name])
else:
if param.name in param_dict:
param.set_data(param_dict[param.name])
def get_worker_upload_list():
return [
'albert.encoder.embedding_hidden_mapping_in.weight',
'albert.encoder.embedding_hidden_mapping_in.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma',
'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta',
'albert.pooler.weight',
'albert.pooler.bias',
'classifier.weight',
'classifier.bias']
def upload_to_server(network, worker_upload_list):
for param in network.trainable_params():
if param.name in worker_upload_list:
param.set_param_fl(push_to_server=True)
def get_worker_download_list():
return [
'albert.encoder.embedding_hidden_mapping_in.weight',
'albert.encoder.embedding_hidden_mapping_in.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma',
'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight',
'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias',
'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma',
'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta'
]
def download_from_server(network, worker_download_list):
for param in network.trainable_params():
if param.name in worker_download_list:
param.set_param_fl(pull_from_server=True)
def get_freeze_list():
return [
'albert.word_embeddings.embedding_table',
'albert.embedding_postprocessor.embedding_table',
'albert.embedding_postprocessor.full_position_embeddings',
'albert.embedding_postprocessor.layernorm.gamma',
'albert.embedding_postprocessor.layernorm.beta'
]
def freeze(network, freeze_list):
for param in network.trainable_params():
if param.name in freeze_list:
param.requires_grad = False