fix error in albert script

This commit is contained in:
w00517672 2021-07-08 21:45:40 +08:00
parent de4b9a94fc
commit 7d49202311
1 changed file with 16 additions and 51 deletions


@@ -17,12 +17,11 @@ import argparse
 import os
 import sys
 from time import time
-from mindspore import context
+import numpy as np
+from mindspore import context, Tensor
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from src.adam import AdamWeightDecayOp as AdamWeightDecay
-from src.tokenization import CustomizedTextTokenizer
 from src.config import train_cfg, server_net_cfg
-from src.dataset import load_dataset
 from src.utils import restore_params
 from src.model import AlbertModelCLS
 from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
@@ -71,12 +70,8 @@ def server_train(args):
     start = time()
     os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
-    tokenizer_dir = args.tokenizer_dir
-    server_data_path = args.server_data_path
     model_path = args.model_path
     output_dir = args.output_dir
-    vocab_map_ids_path = args.vocab_map_ids_path
-    logging_step = args.logging_step
     device_target = args.device_target
     server_mode = args.server_mode
@@ -130,12 +125,6 @@ def server_train(args):
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    # construct tokenizer
-    tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
-    print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
-    sys.stdout.flush()
-    start = time()
     # mindspore context
     context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
     context.set_fl_context(**fl_ctx)
@@ -153,15 +142,12 @@ def server_train(args):
     start = time()
     # train prepare
-    global_step = 0
     param_dict = load_checkpoint(model_path)
     if 'learning_rate' in param_dict:
         del param_dict['learning_rate']
     # server optimizer
-    server_params = [_ for _ in network_with_cls_loss.trainable_params()
-                     if 'word_embeddings' not in _.name
-                     and 'postprocessor' not in _.name]
+    server_params = [_ for _ in network_with_cls_loss.trainable_params()]
     server_decay_params = list(
         filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
     )
@@ -184,41 +170,20 @@ def server_train(args):
     sys.stdout.flush()
     start = time()
-    # server load data
-    server_train_dataset, _ = load_dataset(
-        server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
-        label_list=None,
-        do_shuffle=True,
-        drop_remainder=True,
-        output_dir=None,
-        cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
-    )
-    print('Server data loading is done! Time cost: {}'.format(time() - start))
-    start = time()
     # train process
-    for global_epoch in range(train_cfg.max_global_epoch):
-        for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
-            for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
-                input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
-                model_start_time = time()
-                cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
-                time_cost = time() - model_start_time
-                if global_step % logging_step == 0:
-                    print_text = 'server: '
-                    print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
-                    print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
-                    print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
-                    print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
-                    print(print_text)
-                    sys.stdout.flush()
-                global_step += 1
-                del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
-            output_path = os.path.join(
-                output_dir,
-                str(global_epoch*train_cfg.server_cfg.max_local_epoch+server_local_epoch)+'.ckpt'
-            )
-            save_checkpoint(server_network_train_cell.network, output_path)
+    for _ in range(1):
+        input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
+        attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
+        token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
+        label_ids = Tensor(np.zeros((train_cfg.batch_size,), np.int32))
+        model_start_time = time()
+        cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
+        time_cost = time() - model_start_time
+        print('server: cls_loss {} time_cost {}'.format(cls_loss, time_cost))
+        sys.stdout.flush()
+        del input_ids, attention_mask, token_type_ids, label_ids, cls_loss
+    output_path = os.path.join(output_dir, 'final.ckpt')
+    save_checkpoint(server_network_train_cell.network, output_path)
     print('Training process is done! Time cost: {}'.format(time() - start))
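
For context, the replacement loop no longer reads a tokenized dataset; it drives server_network_train_cell with zero-filled dummy batches and then saves final.ckpt. A minimal sketch of that dummy-batch construction, assuming only MindSpore and NumPy; the batch_size and seq_length values below are illustrative placeholders, not read from src.config:

    import numpy as np
    from mindspore import Tensor

    batch_size, seq_length = 32, 128  # placeholder values standing in for train_cfg.batch_size / server_net_cfg.seq_length

    # zero-filled int32 tensors matching the shapes fed to the train cell
    input_ids = Tensor(np.zeros((batch_size, seq_length), np.int32))
    attention_mask = Tensor(np.zeros((batch_size, seq_length), np.int32))
    token_type_ids = Tensor(np.zeros((batch_size, seq_length), np.int32))
    label_ids = Tensor(np.zeros((batch_size,), np.int32))

    print(input_ids.shape, label_ids.shape)  # (32, 128) and (32,)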