forked from mindspore-Ecosystem/mindspore
fix error in albert script
This commit is contained in:
parent
de4b9a94fc
commit
7d49202311
|
@ -17,12 +17,11 @@ import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from time import time
|
from time import time
|
||||||
from mindspore import context
|
import numpy as np
|
||||||
|
from mindspore import context, Tensor
|
||||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||||
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
||||||
from src.tokenization import CustomizedTextTokenizer
|
|
||||||
from src.config import train_cfg, server_net_cfg
|
from src.config import train_cfg, server_net_cfg
|
||||||
from src.dataset import load_dataset
|
|
||||||
from src.utils import restore_params
|
from src.utils import restore_params
|
||||||
from src.model import AlbertModelCLS
|
from src.model import AlbertModelCLS
|
||||||
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
||||||
|
@ -71,12 +70,8 @@ def server_train(args):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
|
||||||
tokenizer_dir = args.tokenizer_dir
|
|
||||||
server_data_path = args.server_data_path
|
|
||||||
model_path = args.model_path
|
model_path = args.model_path
|
||||||
output_dir = args.output_dir
|
output_dir = args.output_dir
|
||||||
vocab_map_ids_path = args.vocab_map_ids_path
|
|
||||||
logging_step = args.logging_step
|
|
||||||
|
|
||||||
device_target = args.device_target
|
device_target = args.device_target
|
||||||
server_mode = args.server_mode
|
server_mode = args.server_mode
|
||||||
|
@ -130,12 +125,6 @@ def server_train(args):
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
# construct tokenizer
|
|
||||||
tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
|
|
||||||
print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
|
|
||||||
sys.stdout.flush()
|
|
||||||
start = time()
|
|
||||||
|
|
||||||
# mindspore context
|
# mindspore context
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
|
||||||
context.set_fl_context(**fl_ctx)
|
context.set_fl_context(**fl_ctx)
|
||||||
|
@ -153,15 +142,12 @@ def server_train(args):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# train prepare
|
# train prepare
|
||||||
global_step = 0
|
|
||||||
param_dict = load_checkpoint(model_path)
|
param_dict = load_checkpoint(model_path)
|
||||||
if 'learning_rate' in param_dict:
|
if 'learning_rate' in param_dict:
|
||||||
del param_dict['learning_rate']
|
del param_dict['learning_rate']
|
||||||
|
|
||||||
# server optimizer
|
# server optimizer
|
||||||
server_params = [_ for _ in network_with_cls_loss.trainable_params()
|
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
|
||||||
if 'word_embeddings' not in _.name
|
|
||||||
and 'postprocessor' not in _.name]
|
|
||||||
server_decay_params = list(
|
server_decay_params = list(
|
||||||
filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
|
filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
|
||||||
)
|
)
|
||||||
|
@ -184,40 +170,19 @@ def server_train(args):
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# server load data
|
|
||||||
server_train_dataset, _ = load_dataset(
|
|
||||||
server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
|
|
||||||
label_list=None,
|
|
||||||
do_shuffle=True,
|
|
||||||
drop_remainder=True,
|
|
||||||
output_dir=None,
|
|
||||||
cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
|
|
||||||
)
|
|
||||||
print('Server data loading is done! Time cost: {}'.format(time() - start))
|
|
||||||
start = time()
|
|
||||||
|
|
||||||
# train process
|
# train process
|
||||||
for global_epoch in range(train_cfg.max_global_epoch):
|
for _ in range(1):
|
||||||
for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
|
input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
|
||||||
for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
|
attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
|
||||||
input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
|
token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
|
||||||
|
label_ids = Tensor(np.zeros((train_cfg.batch_size,), np.int32))
|
||||||
model_start_time = time()
|
model_start_time = time()
|
||||||
cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
|
cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
|
||||||
time_cost = time() - model_start_time
|
time_cost = time() - model_start_time
|
||||||
if global_step % logging_step == 0:
|
print('server: cls_loss {} time_cost {}'.format(cls_loss, time_cost))
|
||||||
print_text = 'server: '
|
|
||||||
print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
|
|
||||||
print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
|
|
||||||
print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
|
|
||||||
print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
|
|
||||||
print(print_text)
|
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
global_step += 1
|
del input_ids, attention_mask, token_type_ids, label_ids, cls_loss
|
||||||
del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
|
output_path = os.path.join(output_dir, 'final.ckpt')
|
||||||
output_path = os.path.join(
|
|
||||||
output_dir,
|
|
||||||
str(global_epoch*train_cfg.server_cfg.max_local_epoch+server_local_epoch)+'.ckpt'
|
|
||||||
)
|
|
||||||
save_checkpoint(server_network_train_cell.network, output_path)
|
save_checkpoint(server_network_train_cell.network, output_path)
|
||||||
|
|
||||||
print('Training process is done! Time cost: {}'.format(time() - start))
|
print('Training process is done! Time cost: {}'.format(time() - start))
|
||||||
|
|
Loading…
Reference in New Issue