add textrcnn to model zoo

TextRCNN is a text classification model proposed by the Chinese Academy of Sciences in 2015.
TextRCNN combines an RNN and a CNN: a bidirectional RNN first captures the contextual semantic and syntactic information of the input text,
and max pooling then automatically selects the most important features.

Signed-off-by: zymaa <317958662@qq.com>
zymaa 2020-12-02 08:33:55 +00:00
parent 0c7ba7a7fa
commit 16c5013f6e
10 changed files with 795 additions and 0 deletions

data_helpers.py

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset helpers api"""
import argparse
import os
import numpy as np
parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')
args = parser.parse_args()
def dataset_split(label):
"""dataset_split api"""
# label can be 'pos' or 'neg'
pos_samples = []
pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity."+label)
pfhand = open(pos_file, encoding='utf-8')
pos_samples += pfhand.readlines()
pfhand.close()
perm = np.random.permutation(len(pos_samples))
# print(perm[0:int(len(pos_samples)*0.8)])
perm_train = perm[0:int(len(pos_samples)*0.9)]
perm_test = perm[int(len(pos_samples)*0.9):]
pos_samples_train = []
pos_samples_test = []
for pt in perm_train:
pos_samples_train.append(pos_samples[pt])
for pt in perm_test:
pos_samples_test.append(pos_samples[pt])
f = open(os.path.join(args.out_dir, 'train', label), "w")
f.write(''.join(pos_samples_train))
f.close()
f = open(os.path.join(args.out_dir, 'test', label), "w")
f.write(''.join(pos_samples_test))
f.close()
if __name__ == '__main__':
if args.task == "dataset_split":
dataset_split('pos')
dataset_split('neg')
# search(args.q)

eval.py

@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model evaluation script"""
import os
import argparse
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor
from mindspore.common import set_seed
from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset
from src.textrcnn import textrcnn
set_seed(1)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--ckpt_path', type=str)
args = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
loss_cb = LossMonitor()
print("============== Starting Testing ==============")
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, 1, False)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
network.set_train(False)
model = Model(network, loss, opt, metrics={'acc': Accuracy()}, amp_level='O3')
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc))

README.md

@ -0,0 +1,144 @@
# TextRCNN
## Contents
- [TextRCNN Description](#textrcnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [ModelZoo Homepage](#modelzoo-homepage)
## [TextRCNN Description](#contents)
TextRCNN is a text classification model proposed by the Chinese Academy of Sciences in 2015.
TextRCNN combines an RNN and a CNN: a bidirectional RNN first captures the contextual semantic and syntactic information of the input text,
and max pooling then automatically selects the most important features.
Finally, a fully connected layer performs the classification.
The TextCNN network structure contains a convolutional layer and a pooling layer. In RCNN, the feature extraction performed by the convolutional layer is replaced by an RNN, so the overall structure consists of an RNN followed by a pooling layer, hence the name RCNN.
[Paper](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552): Siwei Lai, Liheng Xu, Kang Liu, Jun Zhao: Recurrent Convolutional Neural Networks for Text Classification. AAAI 2015: 2267-2273
## [Model Architecture](#contents)
Specifically, TextRCNN is mainly composed of three parts: a recurrent structure layer, a max-pooling layer, and a fully connected layer. In the paper, the word-vector size is $|e|=50$, the context-vector size is $|c|=50$, the hidden layer size is $H=100$, the learning rate is $\alpha=0.01$, and the vocabulary size is $|V|$; the input is a sequence of words and the output is a vector of class scores.
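The forward computation can be sketched as follows (a minimal NumPy sketch for illustration only, with made-up weight names and the paper's dimensions; the actual MindSpore implementation is in ```src/textrcnn.py```):
```python
# A minimal NumPy sketch of the recurrent structure + max pooling (illustration only;
# all weight names below are made up, dimensions follow the paper, and the actual
# MindSpore implementation is in src/textrcnn.py).
import numpy as np

emb, ctx, hidden, num_classes, seq_len, vocab = 50, 50, 100, 2, 8, 1000
rng = np.random.default_rng(0)

E = rng.normal(size=(vocab, emb))                                        # word embeddings e(w_i)
W_l, W_sl = rng.normal(size=(ctx, ctx)), rng.normal(size=(ctx, emb))     # left-context weights
W_r, W_sr = rng.normal(size=(ctx, ctx)), rng.normal(size=(ctx, emb))     # right-context weights
W2, b2 = rng.normal(size=(hidden, ctx + emb + ctx)), np.zeros(hidden)    # text representation layer
W4, b4 = rng.normal(size=(num_classes, hidden)), np.zeros(num_classes)   # fully connected output layer

words = rng.integers(0, vocab, size=seq_len)   # a toy input sequence of word ids
e = E[words]                                   # (seq_len, emb)

# left contexts are computed left-to-right, right contexts right-to-left
c_l = np.zeros((seq_len, ctx))
c_r = np.zeros((seq_len, ctx))
for i in range(1, seq_len):
    c_l[i] = np.tanh(W_l @ c_l[i - 1] + W_sl @ e[i - 1])
for i in range(seq_len - 2, -1, -1):
    c_r[i] = np.tanh(W_r @ c_r[i + 1] + W_sr @ e[i + 1])

x = np.concatenate([c_l, e, c_r], axis=1)      # word representation [c_l; e; c_r]
y2 = np.tanh(x @ W2.T + b2)                    # per-word latent features
y3 = y2.max(axis=0)                            # element-wise max pooling over the sequence
logits = y3 @ W4.T + b4                        # class scores for the whole text
print(logits.shape)                            # (2,)
```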
## [Dataset](#contents)
Dataset used: [Sentence polarity dataset v1.0](<http://www.cs.cornell.edu/people/pabo/movie-review-data/>)
- Dataset size: 10662 movie comments in 2 classes, 9596 comments for the train set and 1066 comments for the test set.
- Data format: text files. The processed data is in ```./data/```.
## [Environment Requirements](#contents)
- Hardware: Ascend
- Framework: [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below: [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html), [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html).
## [Quick Start](#contents)
- Preparing the environment
```bash
# download the pretrained GoogleNews-vectors-negative300.bin, put it into /tmp
# you can download from https://code.google.com/archive/p/word2vec/,
# or from https://pan.baidu.com/s/1NC2ekA_bJ0uSL7BF3SjhIg, code: yk9a
mv /tmp/GoogleNews-vectors-negative300.bin ./word2vec/
```
- Preparing data
```bash
# split the dataset by the following scripts.
mkdir -p data/test && mkdir -p data/train
python data_helpers.py --task dataset_split --data_dir dataset_dir
```
- Modify the source code in ```mindspore/train/model.py```, line 173, to add "O3":
```python
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
```
- Running on Ascend
```bash
# run training
DEVICE_ID=7 python train.py
# or you can use the shell script to train in background
bash scripts/run_train.sh
# run evaluating
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
# or you can use the shell script to evaluate in background
bash scripts/run_eval.sh
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```text
├── model_zoo
├── README.md // descriptions about all the models
├── textrcnn
├── README.md // descriptions about TextRCNN
├── data_src
│ ├──rt-polaritydata // directory to save the source data
│ ├──rt-polaritydata.README.1.0.txt // readme file of dataset
├── scripts
│ ├──run_train.sh // shell script for train on Ascend
│ ├──run_eval.sh // shell script for evaluation on Ascend
│ ├──sample.txt // example shell commands to run the above two scripts
├── src
│ ├──dataset.py // creating dataset
│ ├──textrcnn.py // textrcnn architecture
│ ├──config.py // parameter configuration
├── train.py // training script
├── eval.py // evaluation script
├── data_helpers.py // dataset split script
├── sample.txt // the shell to train and eval the model without scripts
```
### [Script Parameters](#contents)
Parameters for both training and evaluation can be set in ```config.py```.
- Config for TextRCNN, Sentence polarity dataset v1.0:
```python
'num_epochs': 10, # total training epochs
'batch_size': 64, # training batch size
'cell': 'lstm', # the RNN architecture, can be 'vanilla', 'gru' or 'lstm'.
'opt': 'adam', # the optimizer strategy, can be 'adam' or 'momentum'
'ckpt_folder_path': './ckpt', # the path to save the checkpoints
'preprocess_path': './preprocess', # the directory to save the processed data
'preprocess': 'false', # whether to preprocess the data
'data_path': './data/', # the path to store the split data
'lr': 1e-3, # the training learning rate
'emb_path': './word2vec', # the directory to save the embedding file
'embed_size': 300, # the dimension of the word embedding
'save_checkpoint_steps': 149, # interval (in steps) for saving checkpoints
'keep_checkpoint_max': 10, # max checkpoints to save
'momentum': 0.9 # the momentum rate
```
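The fields can also be read or overridden programmatically before launching training (a minimal sketch, not part of the released scripts; ```textrcnn_cfg``` is an ```EasyDict```, so attribute access and assignment work as shown):
```python
# a minimal sketch: tweak the configuration before calling the training entry point
from src.config import textrcnn_cfg as cfg

cfg.preprocess = 'true'   # regenerate the MindRecord files on the next run of train.py
cfg.cell = 'gru'          # switch the recurrent cell: 'vanilla', 'gru' or 'lstm'
print(cfg.batch_size, cfg.lr)   # 64 0.001
```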
### Performance
| Model | MindSpore + Ascend | TensorFlow+GPU |
| -------------------------- | ----------------------------- | ------------------------- |
| Resource | Ascend 910 | NV SMX2 V100-32G |
| Version | 1.0.1 | 1.4.0 |
| Dataset | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
| batch_size | 64 | 64 |
| Accuracy | 0.78 | 0.78 |
| Speed | 78ms/step | 89ms/step |
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

sample.txt

@ -0,0 +1,2 @@
DEVICE_ID=7 python train.py
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-1_149.ckpt

scripts/run_eval.sh

@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &

scripts/run_train.sh

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python ${BASEPATH}/../train.py > ./train.log 2>&1 &

src/config.py

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config
"""
from easydict import EasyDict as edict
# LSTM CONFIG
textrcnn_cfg = edict({
'pos_dir': 'data/rt-polaritydata/rt-polarity.pos',
'neg_dir': 'data/rt-polaritydata/rt-polarity.neg',
'num_epochs': 10,
'batch_size': 64,
'cell': 'lstm',
'opt': 'adam',
'ckpt_folder_path': './ckpt',
'preprocess_path': './preprocess',
'preprocess': 'false',
'data_path': './data/',
'lr': 1e-3,
'emb_path': './word2vec',
'embed_size': 300,
'save_checkpoint_steps': 149,
'keep_checkpoint_max': 10,
'momentum': 0.9
})

src/dataset.py

@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset api"""
import os
from itertools import chain
import gensim
import numpy as np
from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds
# preprocess part
def encode_samples(tokenized_samples, word_to_idx):
""" encode word to index """
features = []
for sample in tokenized_samples:
feature = []
for token in sample:
if token in word_to_idx:
feature.append(word_to_idx[token])
else:
feature.append(0)
features.append(feature)
return features
def pad_samples(features, maxlen=50, pad=0):
""" pad all features to the same length """
padded_features = []
for feature in features:
if len(feature) >= maxlen:
padded_feature = feature[:maxlen]
else:
padded_feature = feature
while len(padded_feature) < maxlen:
padded_feature.append(pad)
padded_features.append(padded_feature)
return padded_features
def read_imdb(path, seg='train'):
""" read imdb dataset """
pos_or_neg = ['pos', 'neg']
data = []
for label in pos_or_neg:
f = os.path.join(path, seg, label)
rf = open(f, 'r')
for line in rf:
line = line.strip()
if label == 'pos':
data.append([line, 1])
elif label == 'neg':
data.append([line, 0])
return data
def tokenizer(text):
return [tok.lower() for tok in text.split(' ')]
def collect_weight(glove_path, vocab, word_to_idx, embed_size):
""" collect weight """
vocab_size = len(vocab)
# wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
# binary=False, encoding='utf-8')
wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
'GoogleNews-vectors-negative300.bin'), binary=True)
weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
idx_to_word[0] = '<unk>'
for i in range(len(wvmodel.index2word)):
try:
index = word_to_idx[wvmodel.index2word[i]]
except KeyError:
continue
weight_np[index, :] = wvmodel.get_vector(
idx_to_word[word_to_idx[wvmodel.index2word[i]]])
return weight_np
def preprocess(data_path, glove_path, embed_size):
""" preprocess the train and test data """
train_data = read_imdb(data_path, 'train')
test_data = read_imdb(data_path, 'test')
train_tokenized = []
test_tokenized = []
for review, _ in train_data:
train_tokenized.append(tokenizer(review))
for review, _ in test_data:
test_tokenized.append(tokenizer(review))
vocab = set(chain(*train_tokenized))
vocab_size = len(vocab)
print("vocab_size: ", vocab_size)
word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
word_to_idx['<unk>'] = 0
train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32)
train_labels = np.array([score for _, score in train_data]).astype(np.int32)
test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32)
test_labels = np.array([score for _, score in test_data]).astype(np.int32)
weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size)
return train_features, train_labels, test_features, test_labels, weight_np, vocab_size
def get_imdb_data(labels_data, features_data):
data_list = []
for i, (label, feature) in enumerate(zip(labels_data, features_data)):
data_json = {"id": i,
"label": int(label),
"feature": feature.reshape(-1)}
data_list.append(data_json)
return data_list
def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
""" convert imdb dataset to mindrecord """
num_shard = 4
train_features, train_labels, test_features, test_labels, weight_np, _ = \
preprocess(data_path, glove_path, embed_size)
np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np)
print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\
weight_np.shape, "type:", train_labels.dtype)
# write mindrecord
schema_json = {"id": {"type": "int32"},
"label": {"type": "int32"},
"feature": {"type": "int32", "shape": [-1]}}
writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_train.mindrecord'), num_shard)
data = get_imdb_data(train_labels, train_features)
writer.add_schema(schema_json, "nlp_schema")
writer.add_index(["id", "label"])
writer.write_raw_data(data)
writer.commit()
writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_test.mindrecord'), num_shard)
data = get_imdb_data(test_labels, test_features)
writer.add_schema(schema_json, "nlp_schema")
writer.add_index(["id", "label"])
writer.write_raw_data(data)
writer.commit()
def create_dataset(base_path, batch_size, num_epochs, is_train):
"""Create dataset for training."""
columns_list = ["feature", "label"]
num_consumer = 4
if is_train:
path = os.path.join(base_path, 'aclImdb_train.mindrecord0')
else:
path = os.path.join(base_path, 'aclImdb_test.mindrecord0')
data_set = ds.MindDataset(path, columns_list, num_consumer)
ds.config.set_seed(1)
data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
data_set = data_set.batch(batch_size, drop_remainder=True)
return data_set

src/textrcnn.py

@ -0,0 +1,196 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model textrcnn"""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype
class textrcnn(nn.Cell):
"""class textrcnn"""
def __init__(self, weight, vocab_size, cell, batch_size):
super(textrcnn, self).__init__()
self.num_hiddens = 512
self.embed_size = 300
self.num_classes = 2
self.batch_size = batch_size
k = (1 / self.num_hiddens) ** 0.5
self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight)
self.embedding.embedding_table.requires_grad = False
self.cell = cell
self.cast = P.Cast()
self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
if cell == "lstm":
self.lstm = P.DynamicRNN(forget_bias=0.0)
self.w1_fw = Parameter(
np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
np.float16), name="w1_fw")
self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
name="b1_fw")
self.w1_bw = Parameter(
np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
np.float16), name="w1_bw")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
name="b1_bw")
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
if cell == "vanilla":
self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)
if cell == "gru":
self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
self.transpose = P.Transpose()
self.reduce_max = P.ReduceMax()
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.reshape = P.Reshape()
self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
self.output_dense = nn.Dense(self.num_hiddens * 1, 2)
self.concat0 = P.Concat(0)
self.concat2 = P.Concat(2)
self.concat1 = P.Concat(1)
self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens)
self.mydense = nn.Dense(self.num_hiddens, 2)
self.drop_out = nn.Dropout(keep_prob=0.7)
self.tanh = P.Tanh()
self.sigmoid = P.Sigmoid()
self.slice = P.Slice()
# self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,has_bias=has_bias, batch_first=batch_first, bidirectional=bidirectional, dropout=0.0)
def construct(self, x):
"""class construction"""
# x: bs, sl
output_fw = x
output_bw = x
if self.cell == "vanilla":
x = self.embedding(x) # bs, sl, emb_size
x = self.cast(x, mstype.float16)
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
h1_fw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :])) # bs, num_hidden
output_fw = self.expand_dims(h1_fw, 0) # 1, bs, num_hidden
for i in range(1, F.shape(x)[0]):
h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :])) # 1, bs, num_hidden
h1_after_expand_fw = self.expand_dims(h1_fw, 0)
output_fw = self.concat((output_fw, h1_after_expand_fw)) # 2/3/4.., bs, num_hidden
output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
h1_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :])) # bs, num_hidden
output_bw = self.expand_dims(h1_bw, 0) # 1, bs, num_hidden
for i in range(F.shape(x)[0] - 2, -1, -1):
h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :])) # 1, bs, num_hidden
h1_after_expand_bw = self.expand_dims(h1_bw, 0)
output_bw = self.concat((h1_after_expand_bw, output_bw)) # 2/3/4.., bs, num_hidden
output_bw = self.cast(output_bw, mstype.float16) # sl, bs, num_hidden
if self.cell == "gru":
x = self.embedding(x) # bs, sl, emb_size
x = self.cast(x, mstype.float16)
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
h_fw = self.cast(self.h1, mstype.float16)
h_x_fw = self.concat1((h_fw, x[0, :, :]))
r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :]))))
h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
output_fw = self.expand_dims(h_fw, 0)
for i in range(1, F.shape(x)[0]):
h_x_fw = self.concat1((h_fw, x[i, :, :]))
r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :]))))
h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
h_after_expand_fw = self.expand_dims(h_fw, 0)
output_fw = self.concat((output_fw, h_after_expand_fw))
output_fw = self.cast(output_fw, mstype.float16)
h_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :]))))
h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
output_bw = self.expand_dims(h_bw, 0)
for i in range(F.shape(x)[0] - 2, -1, -1):
h_x_bw = self.concat1((h_bw, x[i, :, :]))
r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :]))))
h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
h_after_expand_bw = self.expand_dims(h_bw, 0)
output_bw = self.concat((h_after_expand_bw, output_bw))
output_bw = self.cast(output_bw, mstype.float16)
if self.cell == 'lstm':
x = self.embedding(x) # bs, sl, emb_size
x = self.cast(x, mstype.float16)
x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
x = self.drop_out(x) # sl,bs, emb_size
h1_fw_init = self.h1 # bs, num_hidden
c1_fw_init = self.c1 # bs, num_hidden
_, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
h1_bw_init = self.h1 # bs, num_hidden
c1_bw_init = self.c1 # bs, num_hidden
_, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
output_bw = self.cast(output_bw, mstype.float16) # sl, bs, hidden
c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1])) # sl, bs, num_hidden
c_right = self.concat0((output_bw[1:], self.right_pad_tensor)) # sl, bs, num_hidden
output = self.concat2((c_left, self.cast(x, mstype.float16), c_right)) # sl, bs, 2*num_hidden+emb_size
output = self.cast(output, mstype.float16)
output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
output_dense = self.text_rep_dense(output_flat) # sl*bs, num_hidden
output_dense = self.tanh(output_dense) # sl*bs, num_hidden
output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens)) # sl, bs, num_hidden
output = self.reduce_max(output, 0) # bs, num_hidden
outputs = self.cast(self.mydense(output), mstype.float16) # bs, num_classes
return outputs

train.py

@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model train script"""
import os
import shutil
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.common import set_seed
from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset
from src.dataset import convert_to_mindrecord
from src.textrcnn import textrcnn
set_seed(1)
if __name__ == '__main__':
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if cfg.preprocess == 'true':
print("============== Starting Data Pre-processing ==============")
if os.path.exists(cfg.preprocess_path):
shutil.rmtree(cfg.preprocess_path)
os.mkdir(cfg.preprocess_path)
convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path)
embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
if cfg.opt == "adam":
opt = nn.Adam(params=network.trainable_params(), learning_rate=cfg.lr)
elif cfg.opt == "momentum":
opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
loss_cb = LossMonitor()
model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")
print("============== Starting Training ==============")
ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
print("train success")