From 7d492023119de63621ac1fe44371bd44dddd5802 Mon Sep 17 00:00:00 2001
From: w00517672
Date: Thu, 8 Jul 2021 21:45:40 +0800
Subject: [PATCH] fix error in albert script

---
 tests/st/fl/albert/cloud_train.py | 67 ++++++++-----------------------
 1 file changed, 16 insertions(+), 51 deletions(-)

diff --git a/tests/st/fl/albert/cloud_train.py b/tests/st/fl/albert/cloud_train.py
index 8437694c555..f371c32173b 100644
--- a/tests/st/fl/albert/cloud_train.py
+++ b/tests/st/fl/albert/cloud_train.py
@@ -17,12 +17,11 @@ import argparse
 import os
 import sys
 from time import time
-from mindspore import context
+import numpy as np
+from mindspore import context, Tensor
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from src.adam import AdamWeightDecayOp as AdamWeightDecay
-from src.tokenization import CustomizedTextTokenizer
 from src.config import train_cfg, server_net_cfg
-from src.dataset import load_dataset
 from src.utils import restore_params
 from src.model import AlbertModelCLS
 from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
@@ -71,12 +70,8 @@ def server_train(args):
     start = time()
 
     os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
-    tokenizer_dir = args.tokenizer_dir
-    server_data_path = args.server_data_path
     model_path = args.model_path
     output_dir = args.output_dir
-    vocab_map_ids_path = args.vocab_map_ids_path
-    logging_step = args.logging_step
     device_target = args.device_target
     server_mode = args.server_mode
 
@@ -130,12 +125,6 @@
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    # construct tokenizer
-    tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
-    print('Tokenizer construction is done! Time cost: {}'.format(time() - start))
-    sys.stdout.flush()
-    start = time()
-
     # mindspore context
     context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=True)
     context.set_fl_context(**fl_ctx)
@@ -153,15 +142,12 @@
     start = time()
 
     # train prepare
-    global_step = 0
     param_dict = load_checkpoint(model_path)
     if 'learning_rate' in param_dict:
         del param_dict['learning_rate']
 
     # server optimizer
-    server_params = [_ for _ in network_with_cls_loss.trainable_params()
-                     if 'word_embeddings' not in _.name
-                     and 'postprocessor' not in _.name]
+    server_params = [_ for _ in network_with_cls_loss.trainable_params()]
     server_decay_params = list(
         filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
     )
@@ -184,41 +170,20 @@
     sys.stdout.flush()
     start = time()
 
-    # server load data
-    server_train_dataset, _ = load_dataset(
-        server_data_path, server_net_cfg.seq_length, tokenizer, train_cfg.batch_size,
-        label_list=None,
-        do_shuffle=True,
-        drop_remainder=True,
-        output_dir=None,
-        cyclic_trunc=train_cfg.server_cfg.cyclic_trunc
-    )
-    print('Server data loading is done! Time cost: {}'.format(time() - start))
-    start = time()
 
-    # train process
-    for global_epoch in range(train_cfg.max_global_epoch):
-        for server_local_epoch in range(train_cfg.server_cfg.max_local_epoch):
-            for server_step, server_batch in enumerate(server_train_dataset.create_tuple_iterator()):
-                input_ids, attention_mask, token_type_ids, label_ids, _ = server_batch
-                model_start_time = time()
-                cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
-                time_cost = time() - model_start_time
-                if global_step % logging_step == 0:
-                    print_text = 'server: '
-                    print_text += 'global_epoch {}/{} '.format(global_epoch, train_cfg.max_global_epoch)
-                    print_text += 'local_epoch {}/{} '.format(server_local_epoch, train_cfg.server_cfg.max_local_epoch)
-                    print_text += 'local_step {}/{} '.format(server_step, server_train_dataset.get_dataset_size())
-                    print_text += 'global_step {} cls_loss {} time_cost {}'.format(global_step, cls_loss, time_cost)
-                    print(print_text)
-                    sys.stdout.flush()
-                global_step += 1
-                del input_ids, attention_mask, token_type_ids, label_ids, _, cls_loss
-        output_path = os.path.join(
-            output_dir,
-            str(global_epoch*train_cfg.server_cfg.max_local_epoch+server_local_epoch)+'.ckpt'
-        )
-        save_checkpoint(server_network_train_cell.network, output_path)
+    for _ in range(1):
+        input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
+        attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
+        token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
+        label_ids = Tensor(np.zeros((train_cfg.batch_size,), np.int32))
+        model_start_time = time()
+        cls_loss = server_network_train_cell(input_ids, attention_mask, token_type_ids, label_ids)
+        time_cost = time() - model_start_time
+        print('server: cls_loss {} time_cost {}'.format(cls_loss, time_cost))
+        sys.stdout.flush()
+        del input_ids, attention_mask, token_type_ids, label_ids, cls_loss
+    output_path = os.path.join(output_dir, 'final.ckpt')
+    save_checkpoint(server_network_train_cell.network, output_path)
 
     print('Training process is done! Time cost: {}'.format(time() - start))
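
The replacement loop in the last hunk drives exactly one training step with zero-filled placeholder tensors, which is enough to compile and execute the server-side graph and emit a checkpoint without a tokenizer or dataset. Below is a minimal standalone sketch of that pattern; the toy Dense classifier, shapes, and Adam optimizer are illustrative assumptions standing in for AlbertModelCLS, NetworkTrainCell, and the script's AdamWeightDecay setup, not the script's actual code:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import context, Tensor
    from mindspore.train.serialization import save_checkpoint

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # Hypothetical shapes; the real script uses train_cfg.batch_size
    # and server_net_cfg.seq_length instead.
    batch_size, feature_dim, num_classes = 4, 8, 3

    # Toy classifier plus train cell, standing in for the ALBERT network.
    net = nn.Dense(feature_dim, num_classes)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3)
    train_cell = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
    train_cell.set_train()

    # One step on zero-filled placeholders: builds and runs the graph,
    # then produces a checkpoint, with no input pipeline involved.
    features = Tensor(np.zeros((batch_size, feature_dim), np.float32))
    label_ids = Tensor(np.zeros((batch_size,), np.int32))
    cls_loss = train_cell(features, label_ids)
    print('server: cls_loss {}'.format(cls_loss))

    save_checkpoint(net, 'final.ckpt')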