forked from mindspore-Ecosystem/mindspore

!10425 reformat code and remove useless code in textrcnn

From: @chenmai1102
Reviewed-by: @oacjiewen, @guoqi1024
Signed-off-by: @guoqi1024

commit a525846718
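Most of the diff below applies two mechanical cleanups: PEP 8 spacing around arithmetic operators, and implicit line continuation inside parentheses in place of trailing backslashes, plus removal of commented-out code. A schematic before/after of the pattern, using hypothetical data rather than lines from this repo:

```python
samples = list(range(10))  # hypothetical stand-in data

# before: no space around '*', explicit backslash continuation
cut = int(len(samples)*0.9)
total = sum(samples, \
            0)

# after: operator spacing, implicit continuation inside the parentheses
cut = int(len(samples) * 0.9)
total = sum(samples,
            0)
```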
@@ -16,6 +16,7 @@
 import argparse
 import os
 import numpy as np
 
 parser = argparse.ArgumentParser(description='textrcnn')
 parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
 parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
@@ -24,18 +25,18 @@ parser.add_argument('--out_dir', type=str, help='the target dataset directory.',
 args = parser.parse_args()
 np.random.seed(2)
 
 
 def dataset_split(label):
     """dataset_split api"""
     # label can be 'pos' or 'neg'
     pos_samples = []
-    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity."+label)
+    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
     pfhand = open(pos_file, encoding='utf-8')
     pos_samples += pfhand.readlines()
     pfhand.close()
     perm = np.random.permutation(len(pos_samples))
-    # print(perm[0:int(len(pos_samples)*0.8)])
-    perm_train = perm[0:int(len(pos_samples)*0.9)]
-    perm_test = perm[int(len(pos_samples)*0.9):]
+    perm_train = perm[0:int(len(pos_samples) * 0.9)]
+    perm_test = perm[int(len(pos_samples) * 0.9):]
     pos_samples_train = []
     pos_samples_test = []
     for pt in perm_train:
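dataset_split shuffles the sample indices once and carves off the first 90% for training. A minimal self-contained sketch of the same split logic, on toy strings instead of the rt-polarity files:

```python
import numpy as np

np.random.seed(2)
samples = ["line %d" % i for i in range(10)]   # stand-in for pfhand.readlines()

perm = np.random.permutation(len(samples))     # shuffled index array
cut = int(len(samples) * 0.9)                  # 90/10 train/test boundary
train = [samples[i] for i in perm[:cut]]
test = [samples[i] for i in perm[cut:]]
print(len(train), len(test))                   # 9 1
```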
@@ -51,10 +52,7 @@ def dataset_split(label):
     f.close()
 
 
 if __name__ == '__main__':
     if args.task == "dataset_split":
         dataset_split('pos')
         dataset_split('neg')
-
-    # search(args.q)
-
@@ -32,7 +32,6 @@ from src.textrcnn import textrcnn
 set_seed(1)
 
-
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='textrcnn')
     parser.add_argument('--ckpt_path', type=str)
@@ -46,8 +45,8 @@ if __name__ == '__main__':
     context.set_context(device_id=device_id)
 
     embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
-    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
-                       cell=cfg.cell, batch_size=cfg.batch_size)
+    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
+                       cell=cfg.cell, batch_size=cfg.batch_size)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
     opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
     loss_cb = LossMonitor()
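The hunk above only touches how the network is constructed in the eval script; the restore of --ckpt_path sits outside it. A hedged sketch of the step that presumably follows, using MindSpore's standard serialization API:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Assumed restore step (the actual eval lines are outside this hunk):
# load_checkpoint reads the .ckpt file into a name -> Parameter dict,
# load_param_into_net copies those values into the textrcnn cell built above.
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
```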
@@ -74,7 +74,7 @@ DEVICE_ID=7 python train.py
 bash scripts/run_train.sh
 
 # run evaluating
-DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
+DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}
 # or you can use the shell script to evaluate in background
 bash scripts/run_eval.sh
 ```
@@ -21,6 +21,7 @@ import numpy as np
 from mindspore.mindrecord import FileWriter
 import mindspore.dataset as ds
 
+
 # preprocess part
 def encode_samples(tokenized_samples, word_to_idx):
     """ encode word to index """
@@ -78,7 +79,8 @@ def collect_weight(glove_path, vocab, word_to_idx, embed_size):
     # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
     #                                                           binary=False, encoding='utf-8')
-    wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
-                                                              'GoogleNews-vectors-negative300.bin'), binary=True)
+    wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path,
+                                                                           'GoogleNews-vectors-negative300.bin'),
+                                                              binary=True)
     weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
 
     idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
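After the reflowed load_word2vec_format call, the loaded KeyedVectors fill the weight_np table that textrcnn later consumes as its embedding weight. The exact loop in collect_weight is outside this hunk; a hedged sketch of the usual pattern, with a hypothetical two-word vocabulary and an assumed local path:

```python
import numpy as np
import gensim

# wvmodel as loaded above; the path and vocab here are illustrative assumptions
wvmodel = gensim.models.KeyedVectors.load_word2vec_format(
    'GoogleNews-vectors-negative300.bin', binary=True)
vocab = ["good", "bad"]
embed_size = 300

weight_np = np.zeros((len(vocab) + 1, embed_size)).astype(np.float32)
for i, word in enumerate(vocab):
    if word in wvmodel:
        weight_np[i + 1] = wvmodel[word]  # row 0 stays zero, matching idx_to_word's i + 1 offset
```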
@@ -140,8 +142,8 @@ def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
     preprocess(data_path, glove_path, embed_size)
     np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np)
 
-    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\
-          weight_np.shape, "type:", train_labels.dtype)
+    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",
+          weight_np.shape, "type:", train_labels.dtype)
     # write mindrecord
     schema_json = {"id": {"type": "int32"},
                    "label": {"type": "int32"},
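schema_json is cut off after the "label" field in this hunk; whatever the full schema is, it feeds the FileWriter imported at the top of the file. A hedged sketch of the usual write path (the "feature" field and its shape are assumptions, not the repo's actual schema):

```python
import numpy as np
from mindspore.mindrecord import FileWriter

schema_json = {"id": {"type": "int32"},
               "label": {"type": "int32"},
               "feature": {"type": "int32", "shape": [-1]}}      # assumed field
data_list = [{"id": 0, "label": 1,
              "feature": np.array([4, 8, 15], dtype=np.int32)}]  # toy sample

writer = FileWriter(file_name="train.mindrecord", shard_num=1)
writer.add_schema(schema_json, "textrcnn preprocess sketch")
writer.write_raw_data(data_list)
writer.commit()
```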
@@ -22,8 +22,10 @@ from mindspore.common.parameter import Parameter
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 
+
 class textrcnn(nn.Cell):
     """class textrcnn"""
+
     def __init__(self, weight, vocab_size, cell, batch_size):
         super(textrcnn, self).__init__()
         self.num_hiddens = 512
@@ -89,7 +91,6 @@ class textrcnn(nn.Cell):
         self.tanh = P.Tanh()
         self.sigmoid = P.Sigmoid()
         self.slice = P.Slice()
-        # self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,has_bias=has_bias, batch_first=batch_first, bidirectional=bidirectional, dropout=0.0)
 
     def construct(self, x):
         """class construction"""
@@ -100,34 +101,34 @@ class textrcnn(nn.Cell):
         if self.cell == "vanilla":
             x = self.embedding(x)  # bs, sl, emb_size
             x = self.cast(x, mstype.float16)
             x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
             x = self.drop_out(x)  # sl,bs, emb_size
 
             h1_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
             h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :]))  # bs, num_hidden
             output_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden
 
             for i in range(1, F.shape(x)[0]):
                 h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :]))  # 1, bs, num_hidden
                 h1_after_expand_fw = self.expand_dims(h1_fw, 0)
                 output_fw = self.concat((output_fw, h1_after_expand_fw))  # 2/3/4.., bs, num_hidden
             output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden
 
             h1_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
             h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden
             output_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden
 
             for i in range(F.shape(x)[0] - 2, -1, -1):
                 h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :]))  # 1, bs, num_hidden
                 h1_after_expand_bw = self.expand_dims(h1_bw, 0)
                 output_bw = self.concat((h1_after_expand_bw, output_bw))  # 2/3/4.., bs, num_hidden
             output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden
 
         if self.cell == "gru":
             x = self.embedding(x)  # bs, sl, emb_size
             x = self.cast(x, mstype.float16)
             x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
             x = self.drop_out(x)  # sl,bs, emb_size
 
             h_fw = self.cast(self.h1, mstype.float16)
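The vanilla branch above unrolls the recurrence h_t = tanh(W·h_(t-1) + U·x_t) step by step in both directions, since rnnW_fw/rnnU_fw are plain dense layers. A minimal NumPy sketch of the forward pass, with toy shapes matching the diff's comments (sl = sequence length, bs = batch size):

```python
import numpy as np

sl, bs, emb_size, num_hidden = 5, 2, 4, 3
x = np.random.randn(sl, bs, emb_size).astype(np.float32)        # sl, bs, emb_size
W = np.random.randn(num_hidden, num_hidden).astype(np.float32)  # rnnW_fw analogue
U = np.random.randn(emb_size, num_hidden).astype(np.float32)    # rnnU_fw analogue

h = np.zeros((bs, num_hidden), dtype=np.float32)                # self.h1 analogue
steps = []
for t in range(sl):
    h = np.tanh(h @ W + x[t] @ U)  # h_t = tanh(W h_{t-1} + U x_t)
    steps.append(h)
output_fw = np.stack(steps)        # sl, bs, num_hidden
# the backward pass runs the same recurrence over reversed time steps
```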
@@ -148,7 +149,7 @@ class textrcnn(nn.Cell):
             output_fw = self.concat((output_fw, h_after_expand_fw))
             output_fw = self.cast(output_fw, mstype.float16)
 
             h_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
 
             h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
             r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
@@ -168,29 +169,29 @@ class textrcnn(nn.Cell):
         if self.cell == 'lstm':
             x = self.embedding(x)  # bs, sl, emb_size
             x = self.cast(x, mstype.float16)
             x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
             x = self.drop_out(x)  # sl,bs, emb_size
 
             h1_fw_init = self.h1  # bs, num_hidden
             c1_fw_init = self.c1  # bs, num_hidden
 
             _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
             output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden
 
             h1_bw_init = self.h1  # bs, num_hidden
             c1_bw_init = self.c1  # bs, num_hidden
             _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
             output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, hidden
 
         c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1]))  # sl, bs, num_hidden
         c_right = self.concat0((output_bw[1:], self.right_pad_tensor))  # sl, bs, num_hidden
         output = self.concat2((c_left, self.cast(x, mstype.float16), c_right))  # sl, bs, 2*num_hidden+emb_size
         output = self.cast(output, mstype.float16)
 
         output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
         output_dense = self.text_rep_dense(output_flat)  # sl*bs, num_hidden
         output_dense = self.tanh(output_dense)  # sl*bs, num_hidden
         output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens))  # sl, bs, num_hidden
         output = self.reduce_max(output, 0)  # bs, num_hidden
         outputs = self.cast(self.mydense(output), mstype.float16)  # bs, num_classes
         return outputs
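The tail of construct is the TextRCNN signature step: each token's representation concatenates its left context (forward outputs shifted right by one), its own embedding, and its right context (backward outputs shifted left by one), then runs a tanh dense layer and a max over time. A NumPy sketch of the shape bookkeeping, mirroring c_left/c_right/concat2/reduce_max above (the dense layer is elided):

```python
import numpy as np

sl, bs, num_hidden, emb_size = 5, 2, 3, 4
output_fw = np.random.randn(sl, bs, num_hidden).astype(np.float32)
output_bw = np.random.randn(sl, bs, num_hidden).astype(np.float32)
x = np.random.randn(sl, bs, emb_size).astype(np.float32)
pad = np.zeros((1, bs, num_hidden), dtype=np.float32)  # left/right pad tensors

c_left = np.concatenate([pad, output_fw[:sl - 1]])     # shift fw outputs right
c_right = np.concatenate([output_bw[1:], pad])         # shift bw outputs left
output = np.concatenate([c_left, x, c_right], axis=2)  # sl, bs, 2*num_hidden + emb_size
pooled = output.max(axis=0)                            # max over the time axis
```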
@@ -31,7 +31,6 @@ from src.dataset import convert_to_mindrecord
 from src.textrcnn import textrcnn
 from src.utils import get_lr
 
-
 set_seed(2)
 
 if __name__ == '__main__':
@@ -56,7 +55,7 @@ if __name__ == '__main__':
 
     embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
 
-    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
+    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
                        cell=cfg.cell, batch_size=cfg.batch_size)
 
     ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
@@ -74,7 +73,7 @@ if __name__ == '__main__':
     model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")
 
     print("============== Starting Training ==============")
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
     model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])