forked from mindspore-Ecosystem/mindspore
add textrcnn into model zoo
TextRCNN, a model for text classification proposed by the Chinese Academy of Sciences in 2015. TextRCNN combines RNN and CNN: it first uses a bidirectional RNN to capture the contextual semantic and grammatical information of the input text, and then uses max pooling to automatically select the most important features.

Signed-off-by: zymaa <317958662@qq.com>
This commit is contained in:
parent
0c7ba7a7fa
commit
16c5013f6e
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset helpers api"""
import argparse
import os
import numpy as np

parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')

args = parser.parse_args()


def dataset_split(label):
    """dataset_split api"""
    # label can be 'pos' or 'neg'
    pos_samples = []
    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity."+label)
    pfhand = open(pos_file, encoding='utf-8')
    pos_samples += pfhand.readlines()
    pfhand.close()
    perm = np.random.permutation(len(pos_samples))
    # print(perm[0:int(len(pos_samples)*0.8)])
    perm_train = perm[0:int(len(pos_samples)*0.9)]
    perm_test = perm[int(len(pos_samples)*0.9):]
    pos_samples_train = []
    pos_samples_test = []
    for pt in perm_train:
        pos_samples_train.append(pos_samples[pt])
    for pt in perm_test:
        pos_samples_test.append(pos_samples[pt])
    f = open(os.path.join(args.out_dir, 'train', label), "w")
    f.write(''.join(pos_samples_train))
    f.close()

    f = open(os.path.join(args.out_dir, 'test', label), "w")
    f.write(''.join(pos_samples_test))
    f.close()


if __name__ == '__main__':
    if args.task == "dataset_split":
        dataset_split('pos')
        dataset_split('neg')

    # search(args.q)
@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model evaluation script"""
import os
import argparse
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor
from mindspore.common import set_seed

from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset
from src.textrcnn import textrcnn

set_seed(1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='textrcnn')
    parser.add_argument('--ckpt_path', type=str)
    args = parser.parse_args()
    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target="Ascend")

    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)

    embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
                       cell=cfg.cell, batch_size=cfg.batch_size)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    loss_cb = LossMonitor()
    print("============== Starting Testing ==============")
    ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, 1, False)
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    network.set_train(False)
    model = Model(network, loss, opt, metrics={'acc': Accuracy()}, amp_level='O3')
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))
@@ -0,0 +1,144 @@
# TextRCNN

## Contents

- [TextRCNN Description](#textrcnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [ModelZoo Homepage](#modelzoo-homepage)

## [TextRCNN Description](#contents)

TextRCNN is a model for text classification proposed by the Chinese Academy of Sciences in 2015.
TextRCNN combines RNN and CNN: it first uses a bidirectional RNN to capture the contextual semantic and grammatical information of the input text,
and then uses max pooling to automatically select the most important features.
A fully connected layer is then used for classification.

The TextCNN network structure contains a convolutional layer and a pooling layer. In RCNN, the feature extraction role of the convolutional layer is taken over by an RNN, so the overall structure consists of an RNN and a pooling layer, hence the name RCNN.

[Paper](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552): Siwei Lai, Liheng Xu, Kang Liu, Jun Zhao: Recurrent Convolutional Neural Networks for Text Classification. AAAI 2015: 2267-2273
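
For reference, the computation described above can be written in the paper's notation, where $e(w_i)$ is the embedding of word $w_i$ and $c_l(w_i)$, $c_r(w_i)$ are its left and right context vectors:

$$
c_l(w_i) = f\big(W^{(l)} c_l(w_{i-1}) + W^{(sl)} e(w_{i-1})\big), \qquad
c_r(w_i) = f\big(W^{(r)} c_r(w_{i+1}) + W^{(sr)} e(w_{i+1})\big)
$$

$$
x_i = [c_l(w_i); e(w_i); c_r(w_i)], \qquad
y_i^{(2)} = \tanh\big(W^{(2)} x_i + b^{(2)}\big), \qquad
y^{(3)} = \max_i \, y_i^{(2)}
$$

The max is taken element-wise over all positions $i$, and $y^{(3)}$ is fed to the final fully connected (softmax) layer.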

## [Model Architecture](#contents)

Specifically, TextRCNN is mainly composed of three parts: a recurrent structure layer, a max-pooling layer, and a fully connected layer. In the paper, the word embedding size is $|e| = 50$, the context vector size is $|c| = 50$, the hidden layer size is $H = 100$, the learning rate is $\alpha = 0.01$, and the vocabulary size is $|V|$; the input is a sequence of words, and the output is a vector containing the class scores.
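
As a rough illustration of these three parts, the following NumPy sketch mirrors the same structure. It is illustrative only, not the MindSpore implementation in `src/textrcnn.py`; the weight names and shapes are assumptions that follow the equations above (with the paper's sizes, `W_l` and `W_r` would be 50x50, `W2` 150x100, and `W4` 100 x number of classes).

```python
import numpy as np


def rcnn_forward(emb, W_l, W_sl, W_r, W_sr, W2, b2, W4, b4):
    """Illustrative TextRCNN forward pass for a single sentence.

    emb: (seq_len, emb_size) array of word embeddings.
    Returns unnormalized class scores of shape (num_classes,).
    """
    seq_len = emb.shape[0]
    hidden = W_l.shape[0]
    c_left = np.zeros((seq_len, hidden))
    c_right = np.zeros((seq_len, hidden))
    # recurrent structure: the left context scans forward, the right context scans backward
    for i in range(1, seq_len):
        c_left[i] = np.tanh(W_l @ c_left[i - 1] + W_sl @ emb[i - 1])
    for i in range(seq_len - 2, -1, -1):
        c_right[i] = np.tanh(W_r @ c_right[i + 1] + W_sr @ emb[i + 1])
    # word representation: [left context; word embedding; right context]
    x = np.concatenate([c_left, emb, c_right], axis=1)  # (seq_len, 2*hidden + emb_size)
    y2 = np.tanh(x @ W2 + b2)                           # (seq_len, H)
    y3 = y2.max(axis=0)                                 # max pooling over the sequence
    return y3 @ W4 + b4                                 # (num_classes,)
```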

## [Dataset](#contents)

Dataset used: [Sentence polarity dataset v1.0](<http://www.cs.cornell.edu/people/pabo/movie-review-data/>)

- Dataset size: 10662 movie comments in 2 classes, 9596 comments for the train set and 1066 comments for the test set.
- Data format: text files. The processed data is in ```./data/```.

## [Environment Requirements](#contents)

- Hardware: Ascend
- Framework: [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below: [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html), [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html).

## [Quick Start](#contents)

- Preparing the environment

```shell
# download the pretrained GoogleNews-vectors-negative300.bin and put it into /tmp
# you can download it from https://code.google.com/archive/p/word2vec/,
# or from https://pan.baidu.com/s/1NC2ekA_bJ0uSL7BF3SjhIg, code: yk9a

mv /tmp/GoogleNews-vectors-negative300.bin ./word2vec/
```

- Preparing the data

```shell
# split the dataset with the following commands
mkdir -p data/test && mkdir -p data/train
python data_helpers.py --task dataset_split --data_dir dataset_dir
```

- Modify the source code in ```mindspore/train/model.py```, line 173, to add "O3":

```python
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
```

- Running on Ascend

```shell
# run training
DEVICE_ID=7 python train.py
# or you can use the shell script to train in the background
bash scripts/run_train.sh

# run evaluation
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
# or you can use the shell script to evaluate in the background
bash scripts/run_eval.sh
```

## [Script Description](#contents)

### [Script and Sample Code](#contents)

```text
├── model_zoo
    ├── README.md                              // descriptions about all the models
    ├── textrcnn
        ├── README.md                          // descriptions about TextRCNN
        ├── data_src
        │   ├──rt-polaritydata                 // directory to save the source data
        │   ├──rt-polaritydata.README.1.0.txt  // readme file of the dataset
        ├── scripts
        │   ├──run_train.sh                    // shell script for training on Ascend
        │   ├──run_eval.sh                     // shell script for evaluation on Ascend
        │   ├──sample.txt                      // example shell commands to run the two scripts above
        ├── src
        │   ├──dataset.py                      // creating dataset
        │   ├──textrcnn.py                     // textrcnn architecture
        │   ├──config.py                       // parameter configuration
        ├── train.py                           // training script
        ├── eval.py                            // evaluation script
        ├── data_helpers.py                    // dataset split script
        ├── sample.txt                         // example commands to train and evaluate the model without the scripts
```

### [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for TextRCNN, Sentence polarity dataset v1.0.

```python
'num_epochs': 10,                   # total training epochs
'batch_size': 64,                   # training batch size
'cell': 'lstm',                     # the RNN architecture, can be 'vanilla', 'gru' or 'lstm'
'opt': 'adam',                      # the optimizer strategy, can be 'adam' or 'momentum'
'ckpt_folder_path': './ckpt',       # the path to save the checkpoints
'preprocess_path': './preprocess',  # the directory to save the processed data
'preprocess': 'false',              # whether to preprocess the data
'data_path': './data/',             # the path to store the split data
'lr': 1e-3,                         # the training learning rate
'emb_path': './word2vec',           # the directory to save the embedding file
'embed_size': 300,                  # the dimension of the word embedding
'save_checkpoint_steps': 149,       # number of steps between checkpoint saves
'keep_checkpoint_max': 10,          # maximum number of checkpoints to keep
'momentum': 0.9                     # the momentum rate
```
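
Since the configuration is an `EasyDict`, its fields can also be overridden from a script before training starts. A minimal sketch (assumed usage; the usual workflow is simply to edit `src/config.py`):

```python
# quick smoke-test override of the shared config (illustrative, not part of the provided scripts)
from src.config import textrcnn_cfg as cfg

cfg.cell = 'gru'      # switch the recurrent cell ('vanilla', 'gru' or 'lstm')
cfg.num_epochs = 1    # short run for debugging
print(cfg.cell, cfg.batch_size, cfg.lr)
```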

### Performance

| Model | MindSpore + Ascend | TensorFlow + GPU |
| -------------------------- | ------------------------------ | ------------------------------ |
| Resource | Ascend 910 | NV SMX2 V100-32G |
| Version | 1.0.1 | 1.4.0 |
| Dataset | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
| batch_size | 64 | 64 |
| Accuracy | 0.78 | 0.78 |
| Speed | 78 ms/step | 89 ms/step |

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,2 @@
DEVICE_ID=7 python train.py
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-1_149.ckpt
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python ${BASEPATH}/../train.py > ./train.log 2>&1 &
@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config
"""
from easydict import EasyDict as edict

# LSTM CONFIG
textrcnn_cfg = edict({
    'pos_dir': 'data/rt-polaritydata/rt-polarity.pos',
    'neg_dir': 'data/rt-polaritydata/rt-polarity.neg',
    'num_epochs': 10,
    'batch_size': 64,
    'cell': 'lstm',
    'opt': 'adam',
    'ckpt_folder_path': './ckpt',
    'preprocess_path': './preprocess',
    'preprocess': 'false',
    'data_path': './data/',
    'lr': 1e-3,
    'emb_path': './word2vec',
    'embed_size': 300,
    'save_checkpoint_steps': 149,
    'keep_checkpoint_max': 10,
    'momentum': 0.9
})
@@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset api"""
import os
from itertools import chain
import gensim
import numpy as np

from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds


# preprocess part
def encode_samples(tokenized_samples, word_to_idx):
    """ encode word to index """
    features = []
    for sample in tokenized_samples:
        feature = []
        for token in sample:
            if token in word_to_idx:
                feature.append(word_to_idx[token])
            else:
                feature.append(0)
        features.append(feature)
    return features


def pad_samples(features, maxlen=50, pad=0):
    """ pad all features to the same length """
    padded_features = []
    for feature in features:
        if len(feature) >= maxlen:
            padded_feature = feature[:maxlen]
        else:
            padded_feature = feature
            while len(padded_feature) < maxlen:
                padded_feature.append(pad)
        padded_features.append(padded_feature)
    return padded_features


def read_imdb(path, seg='train'):
    """ read imdb dataset """
    pos_or_neg = ['pos', 'neg']
    data = []
    for label in pos_or_neg:

        f = os.path.join(path, seg, label)
        rf = open(f, 'r')
        for line in rf:
            line = line.strip()
            if label == 'pos':
                data.append([line, 1])
            elif label == 'neg':
                data.append([line, 0])

    return data


def tokenizer(text):
    return [tok.lower() for tok in text.split(' ')]


def collect_weight(glove_path, vocab, word_to_idx, embed_size):
    """ collect weight """
    vocab_size = len(vocab)
    # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
    #                                                           binary=False, encoding='utf-8')
    wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
        'GoogleNews-vectors-negative300.bin'), binary=True)
    weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)

    idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
    idx_to_word[0] = '<unk>'

    for i in range(len(wvmodel.index2word)):
        try:
            index = word_to_idx[wvmodel.index2word[i]]
        except KeyError:
            continue
        weight_np[index, :] = wvmodel.get_vector(
            idx_to_word[word_to_idx[wvmodel.index2word[i]]])
    return weight_np


def preprocess(data_path, glove_path, embed_size):
    """ preprocess the train and test data """
    train_data = read_imdb(data_path, 'train')
    test_data = read_imdb(data_path, 'test')

    train_tokenized = []
    test_tokenized = []
    for review, _ in train_data:
        train_tokenized.append(tokenizer(review))
    for review, _ in test_data:
        test_tokenized.append(tokenizer(review))

    vocab = set(chain(*train_tokenized))
    vocab_size = len(vocab)
    print("vocab_size: ", vocab_size)

    word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
    word_to_idx['<unk>'] = 0

    train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32)
    train_labels = np.array([score for _, score in train_data]).astype(np.int32)
    test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32)
    test_labels = np.array([score for _, score in test_data]).astype(np.int32)

    weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size)
    return train_features, train_labels, test_features, test_labels, weight_np, vocab_size


def get_imdb_data(labels_data, features_data):
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i,
                     "label": int(label),
                     "feature": feature.reshape(-1)}
        data_list.append(data_json)
    return data_list


def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
    """ convert imdb dataset to mindrecord """

    num_shard = 4
    train_features, train_labels, test_features, test_labels, weight_np, _ = \
        preprocess(data_path, glove_path, embed_size)
    np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np)

    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\
          weight_np.shape, "type:", train_labels.dtype)
    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_train.mindrecord'), num_shard)
    data = get_imdb_data(train_labels, train_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_test.mindrecord'), num_shard)
    data = get_imdb_data(test_labels, test_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_dataset(base_path, batch_size, num_epochs, is_train):
    """Create dataset for training."""
    columns_list = ["feature", "label"]
    num_consumer = 4

    if is_train:
        path = os.path.join(base_path, 'aclImdb_train.mindrecord0')
    else:
        path = os.path.join(base_path, 'aclImdb_test.mindrecord0')

    data_set = ds.MindDataset(path, columns_list, num_consumer)
    ds.config.set_seed(1)
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set
@@ -0,0 +1,196 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model textrcnn"""
import numpy as np

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype


class textrcnn(nn.Cell):
    """class textrcnn"""
    def __init__(self, weight, vocab_size, cell, batch_size):
        super(textrcnn, self).__init__()
        self.num_hiddens = 512
        self.embed_size = 300
        self.num_classes = 2
        self.batch_size = batch_size
        k = (1 / self.num_hiddens) ** 0.5

        self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.cell = cell

        self.cast = P.Cast()

        self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
        self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "lstm":
            self.lstm = P.DynamicRNN(forget_bias=0.0)
            self.w1_fw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_fw")
            self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_fw")
            self.w1_bw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_bw")
            self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_bw")
            self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
            self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "vanilla":
            self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
            self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)

        if cell == "gru":
            self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))

        self.transpose = P.Transpose()
        self.reduce_max = P.ReduceMax()
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()

        self.reshape = P.Reshape()
        self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.output_dense = nn.Dense(self.num_hiddens * 1, 2)
        self.concat0 = P.Concat(0)
        self.concat2 = P.Concat(2)
        self.concat1 = P.Concat(1)
        self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens)
        self.mydense = nn.Dense(self.num_hiddens, 2)
        self.drop_out = nn.Dropout(keep_prob=0.7)
        self.tanh = P.Tanh()
        self.sigmoid = P.Sigmoid()
        self.slice = P.Slice()
        # self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias, batch_first=batch_first, bidirectional=bidirectional, dropout=0.0)

    def construct(self, x):
        """class construction"""
        # x: bs, sl
        output_fw = x
        output_bw = x

        if self.cell == "vanilla":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :]))  # bs, num_hidden
            output_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden

            for i in range(1, F.shape(x)[0]):
                h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :]))  # 1, bs, num_hidden
                h1_after_expand_fw = self.expand_dims(h1_fw, 0)
                output_fw = self.concat((output_fw, h1_after_expand_fw))  # 2/3/4.., bs, num_hidden
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden
            output_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :]))  # 1, bs, num_hidden
                h1_after_expand_bw = self.expand_dims(h1_bw, 0)
                output_bw = self.concat((h1_after_expand_bw, output_bw))  # 2/3/4.., bs, num_hidden
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden

        if self.cell == "gru":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h_fw = self.cast(self.h1, mstype.float16)

            h_x_fw = self.concat1((h_fw, x[0, :, :]))
            r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
            z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
            h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :]))))
            h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
            output_fw = self.expand_dims(h_fw, 0)

            for i in range(1, F.shape(x)[0]):
                h_x_fw = self.concat1((h_fw, x[i, :, :]))
                r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
                z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
                h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :]))))
                h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
                h_after_expand_fw = self.expand_dims(h_fw, 0)
                output_fw = self.concat((output_fw, h_after_expand_fw))
            output_fw = self.cast(output_fw, mstype.float16)

            h_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden

            h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
            r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
            z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
            h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :]))))
            h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
            output_bw = self.expand_dims(h_bw, 0)
            for i in range(F.shape(x)[0] - 2, -1, -1):
                h_x_bw = self.concat1((h_bw, x[i, :, :]))
                r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
                z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
                h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :]))))
                h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
                h_after_expand_bw = self.expand_dims(h_bw, 0)
                output_bw = self.concat((h_after_expand_bw, output_bw))
            output_bw = self.cast(output_bw, mstype.float16)
        if self.cell == 'lstm':
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw_init = self.h1  # bs, num_hidden
            c1_fw_init = self.c1  # bs, num_hidden

            _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw_init = self.h1  # bs, num_hidden
            c1_bw_init = self.c1  # bs, num_hidden
            _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, hidden

        c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1]))  # sl, bs, num_hidden
        c_right = self.concat0((output_bw[1:], self.right_pad_tensor))  # sl, bs, num_hidden
        output = self.concat2((c_left, self.cast(x, mstype.float16), c_right))  # sl, bs, 2*num_hidden+emb_size
        output = self.cast(output, mstype.float16)

        output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
        output_dense = self.text_rep_dense(output_flat)  # sl*bs, num_hidden
        output_dense = self.tanh(output_dense)  # sl*bs, num_hidden
        output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens))  # sl, bs, num_hidden
        output = self.reduce_max(output, 0)  # bs, num_hidden
        outputs = self.cast(self.mydense(output), mstype.float16)  # bs, num_classes
        return outputs
@@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model train script"""
import os
import shutil
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.common import set_seed

from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset
from src.dataset import convert_to_mindrecord
from src.textrcnn import textrcnn


set_seed(1)

if __name__ == '__main__':

    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target="Ascend")

    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)

    if cfg.preprocess == 'true':
        print("============== Starting Data Pre-processing ==============")
        if os.path.exists(cfg.preprocess_path):
            shutil.rmtree(cfg.preprocess_path)
        os.mkdir(cfg.preprocess_path)
        convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path)

    embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)

    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
                       cell=cfg.cell, batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    if cfg.opt == "adam":
        opt = nn.Adam(params=network.trainable_params(), learning_rate=cfg.lr)
    elif cfg.opt == "momentum":
        opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    loss_cb = LossMonitor()
    model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")

    print("============== Starting Training ==============")
    ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
    model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
    print("train success")