!10425 reformat code and remove useless code in textrcnn

From: @chenmai1102
Reviewed-by: @oacjiewen,@guoqi1024
Signed-off-by: @guoqi1024
This commit is contained in:
mindspore-ci-bot 2020-12-24 15:38:42 +08:00 committed by Gitee
commit a525846718
6 changed files with 46 additions and 47 deletions

View File

@ -16,6 +16,7 @@
import argparse
import os
import numpy as np
parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
@ -24,18 +25,18 @@ parser.add_argument('--out_dir', type=str, help='the target dataset directory.',
args = parser.parse_args()
np.random.seed(2)
def dataset_split(label):
"""dataset_split api"""
# label can be 'pos' or 'neg'
pos_samples = []
pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity."+label)
pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
pfhand = open(pos_file, encoding='utf-8')
pos_samples += pfhand.readlines()
pfhand.close()
perm = np.random.permutation(len(pos_samples))
# print(perm[0:int(len(pos_samples)*0.8)])
perm_train = perm[0:int(len(pos_samples)*0.9)]
perm_test = perm[int(len(pos_samples)*0.9):]
perm_train = perm[0:int(len(pos_samples) * 0.9)]
perm_test = perm[int(len(pos_samples) * 0.9):]
pos_samples_train = []
pos_samples_test = []
for pt in perm_train:
@ -51,10 +52,7 @@ def dataset_split(label):
f.close()
if __name__ == '__main__':
if args.task == "dataset_split":
dataset_split('pos')
dataset_split('neg')
# search(args.q)

View File

@ -32,7 +32,6 @@ from src.textrcnn import textrcnn
set_seed(1)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--ckpt_path', type=str)
@ -46,8 +45,8 @@ if __name__ == '__main__':
context.set_context(device_id=device_id)
embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
cell=cfg.cell, batch_size=cfg.batch_size)
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
loss_cb = LossMonitor()

View File

@ -74,7 +74,7 @@ DEVICE_ID=7 python train.py
bash scripts/run_train.sh
# run evaluating
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}
# or you can use the shell script to evaluate in background
bash scripts/run_eval.sh
```

View File

@ -21,6 +21,7 @@ import numpy as np
from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds
# preprocess part
def encode_samples(tokenized_samples, word_to_idx):
""" encode word to index """
@ -78,7 +79,8 @@ def collect_weight(glove_path, vocab, word_to_idx, embed_size):
# wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
# binary=False, encoding='utf-8')
wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
'GoogleNews-vectors-negative300.bin'), binary=True)
'GoogleNews-vectors-negative300.bin'),
binary=True)
weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
@ -140,8 +142,8 @@ def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
preprocess(data_path, glove_path, embed_size)
np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np)
print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\
weight_np.shape, "type:", train_labels.dtype)
print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",
weight_np.shape, "type:", train_labels.dtype)
# write mindrecord
schema_json = {"id": {"type": "int32"},
"label": {"type": "int32"},

View File

@ -22,8 +22,10 @@ from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype
class textrcnn(nn.Cell):
"""class textrcnn"""
def __init__(self, weight, vocab_size, cell, batch_size):
super(textrcnn, self).__init__()
self.num_hiddens = 512
@ -89,7 +91,6 @@ class textrcnn(nn.Cell):
self.tanh = P.Tanh()
self.sigmoid = P.Sigmoid()
self.slice = P.Slice()
# self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,has_bias=has_bias, batch_first=batch_first, bidirectional=bidirectional, dropout=0.0)
def construct(self, x):
"""class construction"""
@ -100,34 +101,34 @@ class textrcnn(nn.Cell):
if self.cell == "vanilla":
x = self.embedding(x) # bs, sl, emb_size
x = self.cast(x, mstype.float16)
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
h1_fw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :])) # bs, num_hidden
output_fw = self.expand_dims(h1_fw, 0) # 1, bs, num_hidden
h1_fw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :])) # bs, num_hidden
output_fw = self.expand_dims(h1_fw, 0) # 1, bs, num_hidden
for i in range(1, F.shape(x)[0]):
h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :])) # 1, bs, num_hidden
h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :])) # 1, bs, num_hidden
h1_after_expand_fw = self.expand_dims(h1_fw, 0)
output_fw = self.concat((output_fw, h1_after_expand_fw)) # 2/3/4.., bs, num_hidden
output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
output_fw = self.concat((output_fw, h1_after_expand_fw)) # 2/3/4.., bs, num_hidden
output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
h1_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :])) # bs, num_hidden
output_bw = self.expand_dims(h1_bw, 0) # 1, bs, num_hidden
h1_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :])) # bs, num_hidden
output_bw = self.expand_dims(h1_bw, 0) # 1, bs, num_hidden
for i in range(F.shape(x)[0] - 2, -1, -1):
h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :])) # 1, bs, num_hidden
h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :])) # 1, bs, num_hidden
h1_after_expand_bw = self.expand_dims(h1_bw, 0)
output_bw = self.concat((h1_after_expand_bw, output_bw)) # 2/3/4.., bs, num_hidden
output_bw = self.cast(output_bw, mstype.float16) # sl, bs, num_hidden
output_bw = self.concat((h1_after_expand_bw, output_bw)) # 2/3/4.., bs, num_hidden
output_bw = self.cast(output_bw, mstype.float16) # sl, bs, num_hidden
if self.cell == "gru":
x = self.embedding(x) # bs, sl, emb_size
x = self.cast(x, mstype.float16)
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
h_fw = self.cast(self.h1, mstype.float16)
@ -148,7 +149,7 @@ class textrcnn(nn.Cell):
output_fw = self.concat((output_fw, h_after_expand_fw))
output_fw = self.cast(output_fw, mstype.float16)
h_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
@ -168,29 +169,29 @@ class textrcnn(nn.Cell):
if self.cell == 'lstm':
x = self.embedding(x) # bs, sl, emb_size
x = self.cast(x, mstype.float16)
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
h1_fw_init = self.h1 # bs, num_hidden
c1_fw_init = self.c1 # bs, num_hidden
h1_fw_init = self.h1 # bs, num_hidden
c1_fw_init = self.c1 # bs, num_hidden
_, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
h1_bw_init = self.h1 # bs, num_hidden
c1_bw_init = self.c1 # bs, num_hidden
h1_bw_init = self.h1 # bs, num_hidden
c1_bw_init = self.c1 # bs, num_hidden
_, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
output_bw = self.cast(output_bw, mstype.float16) # sl, bs, hidden
c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1])) # sl, bs, num_hidden
c_right = self.concat0((output_bw[1:], self.right_pad_tensor)) # sl, bs, num_hidden
output = self.concat2((c_left, self.cast(x, mstype.float16), c_right)) # sl, bs, 2*num_hidden+emb_size
output = self.concat2((c_left, self.cast(x, mstype.float16), c_right)) # sl, bs, 2*num_hidden+emb_size
output = self.cast(output, mstype.float16)
output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
output_dense = self.text_rep_dense(output_flat) # sl*bs, num_hidden
output_dense = self.tanh(output_dense) # sl*bs, num_hidden
output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens)) # sl, bs, num_hidden
output = self.reduce_max(output, 0) # bs, num_hidden
outputs = self.cast(self.mydense(output), mstype.float16) # bs, num_classes
output_dense = self.text_rep_dense(output_flat) # sl*bs, num_hidden
output_dense = self.tanh(output_dense) # sl*bs, num_hidden
output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens)) # sl, bs, num_hidden
output = self.reduce_max(output, 0) # bs, num_hidden
outputs = self.cast(self.mydense(output), mstype.float16) # bs, num_classes
return outputs

View File

@ -31,7 +31,6 @@ from src.dataset import convert_to_mindrecord
from src.textrcnn import textrcnn
from src.utils import get_lr
set_seed(2)
if __name__ == '__main__':
@ -56,7 +55,7 @@ if __name__ == '__main__':
embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=cfg.cell, batch_size=cfg.batch_size)
ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
@ -74,7 +73,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")
print("============== Starting Training ==============")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])