fix error in albert script

This commit is contained in:
w00517672 2021-07-08 21:45:40 +08:00
parent de4b9a94fc
commit 7d49202311
1 changed files with 16 additions and 51 deletions

View File

@ -17,12 +17,11 @@ import argparse
import os
import sys
from time import time
from mindspore import context
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.adam import AdamWeightDecayOp as AdamWeightDecay
from src.tokenization import CustomizedTextTokenizer
from src.config import train_cfg, server_net_cfg
from src.dataset import load_dataset
from src.utils import restore_params
from src.model import AlbertModelCLS
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
@ -71,12 +70,8 @@ def server_train(args):
start = time()
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
tokenizer_dir = args.tokenizer_dir
server_data_path = args.server_data_path
model_path = args.model_path
output_dir = args.output_dir
vocab_map_ids_path = args.vocab_map_ids_path
logging_step = args.logging_step
device_target = args.device_target
server_mode = args.server_mode
@ -130,12 +125,6 @@ def server_train(args):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# construct tokenizer
tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()
# mindspore context
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
context.set_fl_context(**fl_ctx)
@ -153,15 +142,12 @@ def server_train(args):
start = time()
# train prepare
global_step = 0
param_dict = load_checkpoint(model_path)
if 'learning_rate' in param_dict:
del param_dict['learning_rate']
# server optimizer
server_params = [_ for _ in network_with_cls_loss.trainable_params()
if 'word_embeddings' not in _.name
and 'postprocessor' not in _.name]
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
server_decay_params = list(
filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
)
@ -184,40 +170,19 @@ def server_train(args):
sys.stdout.flush()
start = time()
# server load data
server_train_dataset, _ = load_dataset(
server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
label_list=None,
do_shuffle=True,
drop_remainder=True,
output_dir=None,
cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
)
print('Server data loading is done! Time cost: {}'.format(time() - start))
start = time()
# train process
for global_epoch in range(train_cfg.max_global_epoch):
for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
for _ in range(1):
input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
label_ids = Tensor(np.zeros((train_cfg.batch_size,), np.int32))
model_start_time = time()
cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
time_cost = time() - model_start_time
if global_step % logging_step == 0:
print_text = 'server: '
print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
print(print_text)
print('server: cls_loss {} time_cost {}'.format(cls_loss, time_cost))
sys.stdout.flush()
global_step += 1
del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
output_path = os.path.join(
output_dir,
str(global_epoch*train_cfg.server_cfg.max_local_epoch+server_local_epoch)+'.ckpt'
)
del input_ids, attention_mask, token_type_ids, label_ids, cls_loss
output_path = os.path.join(output_dir, 'final.ckpt')
save_checkpoint(server_network_train_cell.network, output_path)
print('Training process is done! Time cost: {}'.format(time() - start))