131 lines
4.4 KiB
Python
131 lines
4.4 KiB
Python
"""
|
|
Running scripts.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from space.args import parse_args
|
|
from space.args import str2bool
|
|
from space.data.dataset import Dataset
|
|
from space.data.fields.gen_field import BPETextField, MultiWOZBPETextField, CamRestBPETextField, KvretBPETextField
|
|
from space.trainers.gen_trainer import Trainer, MultiWOZTrainer, CamRestTrainer, KvretTrainer
|
|
from space.models.model_base import ModelBase
|
|
from space.models.generator import Generator
|
|
from space.utils.eval import MultiWOZEvaluator, CamRestEvaluator, KvretEvaluator
|
|
|
|
|
|
def main():
    """Parse hyper-parameters, build the model and run training and/or inference.

    The steps are executed in this order:

    1. Register the script-level flags plus the command line arguments of the
       text field, dataset, trainer, model and generator, then parse them.
    2. Print the resolved hyper-parameters as JSON and save them to
       ``<save_dir>/hparams.json``.
    3. Seed the Python, NumPy and PyTorch random number generators.
    4. Build the dataset-specific reader and evaluator, the generator, the
       model and the dataset-specific trainer.
    5. Load the trainer state, then train if ``--do_train`` is set and run
       inference on the test split if ``--do_infer`` is set.

    Raises:
        NotImplementedError: If ``hparams.data_name`` is not one of
            ``'camrest'``, ``'multiwoz'`` or ``'kvret'``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--do_train", type=str2bool, default=False,
                        help="Whether to run trainning in plato setting.")
    # NOTE(review): --do_test is registered but never read in this function;
    # confirm whether the trainer consumes it through hparams.
    parser.add_argument("--do_test", type=str2bool, default=False,
                        help="Whether to run evaluation on the test dataset.")
    parser.add_argument("--do_infer", type=str2bool, default=False,
                        help="Whether to run inference on the test dataset.")
    parser.add_argument("--num_infer_batches", type=int, default=None,
                        help="The number of batches need to infer.\n"
                        "Stay 'None': infer on entrie test dataset.")
    parser.add_argument("--hparams_file", type=str, default=None,
                        help="Loading hparams setting from file(.json format).")
    BPETextField.add_cmdline_argument(parser)
    Dataset.add_cmdline_argument(parser)
    Trainer.add_cmdline_argument(parser)
    ModelBase.add_cmdline_argument(parser)
    Generator.add_cmdline_argument(parser)

    hparams = parse_args(parser)
    # The GPU is used only when CUDA is present AND at least one GPU was requested.
    hparams.use_gpu = torch.cuda.is_available() and hparams.gpu >= 1

    print(json.dumps(hparams, indent=2))

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(hparams.save_dir, exist_ok=True)
    hparams.save(os.path.join(hparams.save_dir, "hparams.json"))

    def to_tensor(array):
        """
        numpy array -> tensor (moved to the GPU when hparams.use_gpu is set)
        """
        array = torch.tensor(array)
        return array.cuda() if hparams.use_gpu else array

    def set_seed(seed):
        """ fix random seed """
        torch.manual_seed(seed)
        # Seed every visible GPU, not only the current one, because the model
        # may be wrapped in DataParallel below.
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)

    # set seed
    set_seed(seed=hparams.seed)

    # Dataset name -> (reader class, evaluator class, trainer class).
    # One table keeps the reader/evaluator choice and the trainer choice in sync.
    components = {
        'camrest': (CamRestBPETextField, CamRestEvaluator, CamRestTrainer),
        'multiwoz': (MultiWOZBPETextField, MultiWOZEvaluator, MultiWOZTrainer),
        'kvret': (KvretBPETextField, KvretEvaluator, KvretTrainer),
    }
    if hparams.data_name not in components:
        raise NotImplementedError("Other dataset's reader and evaluator to be implemented !")
    field_cls, evaluator_cls, trainer_cls = components[hparams.data_name]

    # set reader and evaluator
    bpe = field_cls(hparams)
    evaluator = evaluator_cls(reader=bpe)

    # Embedding table sizes are taken from the reader.
    hparams.Model.num_token_embeddings = bpe.vocab_size
    hparams.Model.num_turn_embeddings = bpe.max_ctx_turn + 1

    # set data and data status
    if hparams.do_train:
        train_data = bpe.get_batches('train')
        dev_data = bpe.get_batches('dev')
    else:
        train_data, dev_data = [], []

    # set generator
    generator = Generator.create(hparams, reader=bpe)

    # construct model
    model = ModelBase.create(hparams, reader=bpe, generator=generator)
    print("Total number of parameters in networks is {}".format(sum(x.numel() for x in model.parameters())))

    # multi-gpu
    if hparams.gpu > 1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # construct trainer
    trainer = trainer_cls(model, to_tensor, hparams, reader=bpe, evaluator=evaluator)

    # set optimizer and lr_scheduler
    if hparams.do_train:
        trainer.set_optimizers()

    # load model, optimizer and lr_scheduler
    trainer.load()

    if hparams.do_train:
        # training process
        trainer.train(train_data=train_data, dev_data=dev_data)

    if hparams.do_infer:
        # infer process
        trainer.infer(data_type='test')
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|