forked from mindspore-Ecosystem/mindspore
fix error in albert script
This commit is contained in:
parent
de4b9a94fc
commit
7d49202311
|
@ -17,12 +17,11 @@ import argparse
|
|||
import os
|
||||
import sys
|
||||
from time import time
|
||||
from mindspore import context
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
||||
from src.tokenization import CustomizedTextTokenizer
|
||||
from src.config import train_cfg, server_net_cfg
|
||||
from src.dataset import load_dataset
|
||||
from src.utils import restore_params
|
||||
from src.model import AlbertModelCLS
|
||||
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
||||
|
@ -71,12 +70,8 @@ def server_train(args):
|
|||
start = time()
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
|
||||
tokenizer_dir = args.tokenizer_dir
|
||||
server_data_path = args.server_data_path
|
||||
model_path = args.model_path
|
||||
output_dir = args.output_dir
|
||||
vocab_map_ids_path = args.vocab_map_ids_path
|
||||
logging_step = args.logging_step
|
||||
|
||||
device_target = args.device_target
|
||||
server_mode = args.server_mode
|
||||
|
@ -130,12 +125,6 @@ def server_train(args):
|
|||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# construct tokenizer
|
||||
tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
|
||||
print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
|
||||
sys.stdout.flush()
|
||||
start = time()
|
||||
|
||||
# mindspore context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
|
||||
context.set_fl_context(**fl_ctx)
|
||||
|
@ -153,15 +142,12 @@ def server_train(args):
|
|||
start = time()
|
||||
|
||||
# train prepare
|
||||
global_step = 0
|
||||
param_dict = load_checkpoint(model_path)
|
||||
if 'learning_rate' in param_dict:
|
||||
del param_dict['learning_rate']
|
||||
|
||||
# server optimizer
|
||||
server_params = [_ for _ in network_with_cls_loss.trainable_params()
|
||||
if 'word_embeddings' not in _.name
|
||||
and 'postprocessor' not in _.name]
|
||||
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
|
||||
server_decay_params = list(
|
||||
filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
|
||||
)
|
||||
|
@ -184,40 +170,19 @@ def server_train(args):
|
|||
sys.stdout.flush()
|
||||
start = time()
|
||||
|
||||
# server load data
|
||||
server_train_dataset, _ = load_dataset(
|
||||
server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
|
||||
label_list=None,
|
||||
do_shuffle=True,
|
||||
drop_remainder=True,
|
||||
output_dir=None,
|
||||
cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
|
||||
)
|
||||
print('Server data loading is done! Time cost: {}'.format(time() - start))
|
||||
start = time()
|
||||
|
||||
# train process
|
||||
for global_epoch in range(train_cfg.max_global_epoch):
|
||||
for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
|
||||
for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
|
||||
input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
|
||||
for _ in range(1):
|
||||
input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
|
||||
attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
|
||||
token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
|
||||
label_ids = Tensor(np.zeros((train_cfg.batch_size,), np.int32))
|
||||
model_start_time = time()
|
||||
cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
|
||||
time_cost = time() - model_start_time
|
||||
if global_step % logging_step == 0:
|
||||
print_text = 'server: '
|
||||
print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
|
||||
print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
|
||||
print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
|
||||
print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
|
||||
print(print_text)
|
||||
print('server: cls_loss {} time_cost {}'.format(cls_loss, time_cost))
|
||||
sys.stdout.flush()
|
||||
global_step += 1
|
||||
del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
|
||||
output_path = os.path.join(
|
||||
output_dir,
|
||||
str(global_epoch*train_cfg.server_cfg.max_local_epoch+server_local_epoch)+'.ckpt'
|
||||
)
|
||||
del input_ids, attention_mask, token_type_ids, label_ids, cls_loss
|
||||
output_path = os.path.join(output_dir, 'final.ckpt')
|
||||
save_checkpoint(server_network_train_cell.network, output_path)
|
||||
|
||||
print('Training process is done! Time cost: {}'.format(time() - start))
|
||||
|
|
Loading…
Reference in New Issue