forked from mindspore-Ecosystem/mindspore
!1486 add train and eval script for LSTM
Merge pull request !1486 from caojian05/ms_r0.3_dev
Commit f51a745931

@@ -0,0 +1,100 @@ README.md
# LSTM Example

## Description

This example trains and evaluates an LSTM-based sentiment classification model on the aclImdb dataset.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Download the aclImdb_v1 dataset.

> Unzip the aclImdb_v1 dataset to any path you want; the folder structure should be as follows:
> ```
> .
> ├── train # training dataset
> └── test # inference dataset
> ```

- Download the GloVe file.

> Unzip glove.6B.zip to any path you want; the folder structure should be as follows:
> ```
> .
> ├── glove.6B.100d.txt
> ├── glove.6B.200d.txt
> ├── glove.6B.300d.txt # we will use this one later.
> └── glove.6B.50d.txt
> ```

> Add a new first line to the file named `glove.6B.300d.txt` (see the snippet after this note).
> The line states that the file contains 400,000 words, each represented by a 300-dimensional word vector:
> ```
> 400000 300
> ```
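This header line turns the GloVe file into the word2vec text format that `gensim.models.KeyedVectors.load_word2vec_format` (used in `imdb.py`) expects. A minimal Python sketch for prepending it, assuming the placeholder path below points at your unzipped `glove.6B.300d.txt`:

```
# Prepend the "400000 300" header to the 300-dimensional GloVe file.
# The path is a placeholder; adjust it to where you unzipped glove.6B.zip.
glove_file = "./glove/glove.6B.300d.txt"

with open(glove_file, "r", encoding="utf-8") as f:
    content = f.read()
with open(glove_file, "w", encoding="utf-8") as f:
    f.write("400000 300\n" + content)
```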

## Running the Example

### Training

```
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path > out.train.log 2>&1 &
```
The Python command above runs in the background; you can view the results in the file `out.train.log`.

After training, several checkpoint files are saved under the script folder by default (for example `lstm-20-390.ckpt`, the file used in the evaluation command below).

You will see loss values like the following:
```
# grep "loss is " out.train.log
epoch: 1 step: 390, loss is 0.6003723
epoch: 2 step: 390, loss is 0.35312173
...
```

### Evaluation

```
python eval.py --ckpt_path=./lstm-20-390.ckpt > out.eval.log 2>&1 &
```
The Python command above runs in the background; you can view the results in the file `out.eval.log`.

You will see the accuracy reported like the following:
```
# grep "acc" out.eval.log
result: {'acc': 0.83}
```

## Usage

### Training
```
usage: train.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
                [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
                [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]

parameters/options:
  --preprocess          whether to preprocess the data.
  --aclimdb_path        path where the dataset is stored.
  --glove_path          path where the GloVe file is stored.
  --preprocess_path     path where the pre-processed data is stored.
  --ckpt_path           the path to save the checkpoint file.
  --device_target       the target device to run on; supports "GPU" and "CPU".
```

### Evaluation

```
usage: eval.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
               [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
               [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]

parameters/options:
  --preprocess          whether to preprocess the data.
  --aclimdb_path        path where the dataset is stored.
  --glove_path          path where the GloVe file is stored.
  --preprocess_path     path where the pre-processed data is stored.
  --ckpt_path           the checkpoint file path used to evaluate the model.
  --device_target       the target device to run on; supports "GPU" and "CPU".
```

@@ -0,0 +1,33 @@ config.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict

# LSTM CONFIG
lstm_cfg = edict({
    'num_classes': 2,
    'learning_rate': 0.1,
    'momentum': 0.9,
    'num_epochs': 20,
    'batch_size': 64,
    'embed_size': 300,
    'num_hiddens': 100,
    'num_layers': 2,
    'bidirectional': True,
    'save_checkpoint_steps': 390,
    'keep_checkpoint_max': 10
})
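These values are consumed by `train.py` and `eval.py` via `from config import lstm_cfg as cfg`. A minimal sketch of inspecting and overriding a field for a quick smoke test (EasyDict supports attribute-style access; the one-epoch override below is only an illustration, not part of the scripts):

```
from config import lstm_cfg as cfg

print(cfg.batch_size)   # 64
cfg.num_epochs = 1      # attribute-style writes work on EasyDict; shortens a test run
```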

@@ -0,0 +1,92 @@ dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py.
"""
import os

import numpy as np

from imdb import ImdbParser
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter


def create_dataset(data_home, batch_size, repeat_num=1, training=True):
    """Data operations."""
    ds.config.set_seed(1)
    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
    if not training:
        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord0")

    data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4)

    # shuffle, batch and repeat the dataset
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
    data_set = data_set.repeat(count=repeat_num)

    return data_set


def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
    """
    Convert the imdb dataset to a MindRecord dataset.
    """
    if weight_np is not None:
        np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)

    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord")
    if not training:
        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord")

    def get_imdb_data(features, labels):
        data_list = []
        for i, (label, feature) in enumerate(zip(labels, features)):
            data_json = {"id": i,
                         "label": int(label),
                         "feature": feature.reshape(-1)}
            data_list.append(data_json)
        return data_list

    writer = FileWriter(data_dir, shard_num=4)
    data = get_imdb_data(features, labels)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()


def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
    """
    Convert the imdb dataset to a MindRecord dataset.
    """
    parser = ImdbParser(aclimdb_path, glove_path, embed_size)
    parser.parse()

    if not os.path.exists(preprocess_path):
        print(f"preprocess path {preprocess_path} does not exist")
        os.makedirs(preprocess_path)

    train_features, train_labels, train_weight_np = parser.get_datas('train')
    _convert_to_mindrecord(preprocess_path, train_features, train_labels, train_weight_np)

    test_features, test_labels, _ = parser.get_datas('test')
    _convert_to_mindrecord(preprocess_path, test_features, test_labels, training=False)
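A minimal standalone preprocessing run, mirroring what `train.py` does when `--preprocess=true`; the paths below are just the argparse defaults used by the scripts and should be adjusted to your setup:

```
from config import lstm_cfg as cfg
from dataset import convert_to_mindrecord

# Parse aclImdb with GloVe embeddings and write the MindRecord shards.
convert_to_mindrecord(cfg.embed_size, "./aclImdb", "./preprocess", "./glove")

# ./preprocess should then contain weight.txt plus the
# aclImdb_train.mindrecord0..3 and aclImdb_test.mindrecord0..3 shards
# (FileWriter is created with shard_num=4).
```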

@@ -0,0 +1,81 @@ eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################eval lstm example on aclImdb########################
python eval.py --ckpt_path=./lstm-20-390.ckpt
"""
import argparse
import os

import numpy as np

from config import lstm_cfg as cfg
from dataset import create_dataset, convert_to_mindrecord
from mindspore import Tensor, nn, Model, context
from mindspore.model_zoo.lstm import SentimentNet
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
    parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
                        help='whether to preprocess data.')
    parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
                        help='path where the dataset is stored.')
    parser.add_argument('--glove_path', type=str, default="./glove",
                        help='path where the GloVe is stored.')
    parser.add_argument('--preprocess_path', type=str, default="./preprocess",
                        help='path where the pre-process data is stored.')
    parser.add_argument('--ckpt_path', type=str, default=None,
                        help='the checkpoint file path used to evaluate model.')
    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
                        help='the target device to run, support "GPU", "CPU". Default: "GPU".')
    args = parser.parse_args()

    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target=args.device_target)

    if args.preprocess == "true":
        print("============== Starting Data Pre-processing ==============")
        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)

    # Load the pre-computed GloVe embedding table saved during preprocessing.
    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
    network = SentimentNet(vocab_size=embedding_table.shape[0],
                           embed_size=cfg.embed_size,
                           num_hiddens=cfg.num_hiddens,
                           num_layers=cfg.num_layers,
                           bidirectional=cfg.bidirectional,
                           num_classes=cfg.num_classes,
                           weight=Tensor(embedding_table),
                           batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
    loss_cb = LossMonitor()

    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Testing ==============")
    ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, training=False)
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    if args.device_target == "CPU":
        acc = model.eval(ds_eval, dataset_sink_mode=False)
    else:
        acc = model.eval(ds_eval)
    print("============== Accuracy:{} ==============".format(acc))

@@ -0,0 +1,155 @@ imdb.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
imdb dataset parser.
"""
import os
from itertools import chain

import gensim
import numpy as np


class ImdbParser():
    """
    parse aclImdb data to features and labels.
    sentence->tokenized->encoded->padding->features
    """

    def __init__(self, imdb_path, glove_path, embed_size=300):
        self.__segs = ['train', 'test']
        self.__label_dic = {'pos': 1, 'neg': 0}
        self.__imdb_path = imdb_path
        self.__glove_dim = embed_size
        self.__glove_file = os.path.join(glove_path, 'glove.6B.' + str(self.__glove_dim) + 'd.txt')

        # properties
        self.__imdb_datas = {}
        self.__features = {}
        self.__labels = {}
        self.__vacab = {}
        self.__word2idx = {}
        self.__weight_np = {}
        self.__wvmodel = None

    def parse(self):
        """
        parse imdb data into memory
        """
        self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)

        for seg in self.__segs:
            self.__parse_imdb_datas(seg)
            self.__parse_features_and_labels(seg)
            self.__gen_weight_np(seg)

    def __parse_imdb_datas(self, seg):
        """
        load data from txt files
        """
        data_lists = []
        for label_name, label_id in self.__label_dic.items():
            sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
            for file in os.listdir(sentence_dir):
                with open(os.path.join(sentence_dir, file), mode='r', encoding='utf8') as f:
                    sentence = f.read().replace('\n', '')
                    data_lists.append([sentence, label_id])
        self.__imdb_datas[seg] = data_lists

    def __parse_features_and_labels(self, seg):
        """
        parse features and labels
        """
        features = []
        labels = []
        for sentence, label in self.__imdb_datas[seg]:
            features.append(sentence)
            labels.append(label)

        self.__features[seg] = features
        self.__labels[seg] = labels

        # update feature to tokenized
        self.__updata_features_to_tokenized(seg)
        # parse vacab
        self.__parse_vacab(seg)
        # encode feature
        self.__encode_features(seg)
        # padding feature
        self.__padding_features(seg)

    def __updata_features_to_tokenized(self, seg):
        tokenized_features = []
        for sentence in self.__features[seg]:
            tokenized_sentence = [word.lower() for word in sentence.split(" ")]
            tokenized_features.append(tokenized_sentence)
        self.__features[seg] = tokenized_features

    def __parse_vacab(self, seg):
        # vocab
        tokenized_features = self.__features[seg]
        vocab = set(chain(*tokenized_features))
        self.__vacab[seg] = vocab

        # word_to_idx: {'hello': 1, 'world':111, ... '<unk>': 0}
        word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
        word_to_idx['<unk>'] = 0
        self.__word2idx[seg] = word_to_idx

    def __encode_features(self, seg):
        """ encode word to index """
        # the vocabulary built from the training split is reused for the test split
        word_to_idx = self.__word2idx['train']
        encoded_features = []
        for tokenized_sentence in self.__features[seg]:
            encoded_sentence = []
            for word in tokenized_sentence:
                encoded_sentence.append(word_to_idx.get(word, 0))
            encoded_features.append(encoded_sentence)
        self.__features[seg] = encoded_features

    def __padding_features(self, seg, maxlen=500, pad=0):
        """ pad all features to the same length """
        padded_features = []
        for feature in self.__features[seg]:
            if len(feature) >= maxlen:
                padded_feature = feature[:maxlen]
            else:
                padded_feature = feature
                while len(padded_feature) < maxlen:
                    padded_feature.append(pad)
            padded_features.append(padded_feature)
        self.__features[seg] = padded_features

    def __gen_weight_np(self, seg):
        """
        generate the embedding weight with gensim
        """
        weight_np = np.zeros((len(self.__word2idx[seg]), self.__glove_dim), dtype=np.float32)
        for word, idx in self.__word2idx[seg].items():
            if word not in self.__wvmodel:
                continue
            word_vector = self.__wvmodel.get_vector(word)
            weight_np[idx, :] = word_vector

        self.__weight_np[seg] = weight_np

    def get_datas(self, seg):
        """
        return features, labels, and weight
        """
        features = np.array(self.__features[seg]).astype(np.int32)
        labels = np.array(self.__labels[seg]).astype(np.int32)
        weight = np.array(self.__weight_np[seg])
        return features, labels, weight
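A hypothetical direct use of the parser (the normal entry point is `dataset.convert_to_mindrecord` above); the paths are placeholders, and the printed shapes assume the standard aclImdb split of 25,000 training reviews padded to 500 tokens as in `__padding_features`:

```
from imdb import ImdbParser

parser = ImdbParser("./aclImdb", "./glove", embed_size=300)
parser.parse()  # loads GloVe, tokenizes, builds the vocab, encodes and pads

features, labels, weight = parser.get_datas('train')
print(features.shape, labels.shape, weight.shape)
# roughly expected: (25000, 500) (25000,) (vocab_size, 300)
```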

@@ -0,0 +1,83 @@ train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train lstm example on aclImdb########################
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path
"""
import argparse
import os

import numpy as np

from config import lstm_cfg as cfg
from dataset import convert_to_mindrecord
from dataset import create_dataset
from mindspore import Tensor, nn, Model, context
from mindspore.model_zoo.lstm import SentimentNet
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
    parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
                        help='whether to preprocess data.')
    parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
                        help='path where the dataset is stored.')
    parser.add_argument('--glove_path', type=str, default="./glove",
                        help='path where the GloVe is stored.')
    parser.add_argument('--preprocess_path', type=str, default="./preprocess",
                        help='path where the pre-process data is stored.')
    parser.add_argument('--ckpt_path', type=str, default="./",
                        help='the path to save the checkpoint file.')
    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
                        help='the target device to run, support "GPU", "CPU". Default: "GPU".')
    args = parser.parse_args()

    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target=args.device_target)

    if args.preprocess == "true":
        print("============== Starting Data Pre-processing ==============")
        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)

    # Load the pre-computed GloVe embedding table saved during preprocessing.
    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
    network = SentimentNet(vocab_size=embedding_table.shape[0],
                           embed_size=cfg.embed_size,
                           num_hiddens=cfg.num_hiddens,
                           num_layers=cfg.num_layers,
                           bidirectional=cfg.bidirectional,
                           num_classes=cfg.num_classes,
                           weight=Tensor(embedding_table),
                           batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
    loss_cb = LossMonitor()

    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Training ==============")
    ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    if args.device_target == "CPU":
        model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
    else:
        model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
    print("============== Training Success ==============")

@@ -0,0 +1,115 @@ lstm.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LSTM."""
import math

import numpy as np

from mindspore import Parameter, Tensor, nn
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P


def init_lstm_weight(
        input_size,
        hidden_size,
        num_layers,
        bidirectional,
        has_bias=True):
    """Initialize lstm weight."""
    num_directions = 1
    if bidirectional:
        num_directions = 2

    # Each direction of each layer holds 4*hidden_size rows of input-hidden and
    # hidden-hidden weights (one block per gate), plus two bias vectors when
    # has_bias is True; everything is flattened into a single 1-D buffer.
    weight_size = 0
    gate_size = 4 * hidden_size
    for layer in range(num_layers):
        for _ in range(num_directions):
            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
            weight_size += gate_size * input_layer_size
            weight_size += gate_size * hidden_size
            if has_bias:
                weight_size += 2 * gate_size

    stdv = 1 / math.sqrt(hidden_size)
    w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
    w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')

    return w


# Initialize short-term memory (h) and long-term memory (c) to 0
def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    num_directions = 1
    if bidirectional:
        num_directions = 2

    h = Tensor(
        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    c = Tensor(
        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    return h, c


class SentimentNet(nn.Cell):
    """Sentiment network structure."""

    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_hiddens,
                 num_layers,
                 bidirectional,
                 num_classes,
                 weight,
                 batch_size):
        super(SentimentNet, self).__init__()
        # Map words to vectors
        self.embedding = nn.Embedding(vocab_size,
                                      embed_size,
                                      embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        self.encoder = nn.LSTM(input_size=embed_size,
                               hidden_size=num_hiddens,
                               num_layers=num_layers,
                               has_bias=True,
                               bidirectional=bidirectional,
                               dropout=0.0)
        w_init = init_lstm_weight(
            embed_size,
            num_hiddens,
            num_layers,
            bidirectional)
        self.encoder.weight = w_init
        self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)

        self.concat = P.Concat(1)
        if bidirectional:
            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
        else:
            self.decoder = nn.Dense(num_hiddens * 2, num_classes)

    def construct(self, inputs):
        # inputs: (batch_size, 500) token ids -> embeddings: (64, 500, 300)
        embeddings = self.embedding(inputs)
        embeddings = self.trans(embeddings, self.perm)
        output, _ = self.encoder(embeddings, (self.h, self.c))
        # states[i] size(64,200) -> encoding.size(64,400)
        encoding = self.concat((output[0], output[1]))
        outputs = self.decoder(encoding)
        return outputs