!10534 add export to textrcnn, remove useless code in eval

From: @chenmai1102
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
mindspore-ci-bot 2020-12-27 15:04:12 +08:00 committed by Gitee
commit c52a4381dd
8 changed files with 77 additions and 22 deletions

View File

@@ -23,7 +23,6 @@ parser.add_argument('--data_dir', type=str, help='the source dataset directory.'
parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')
args = parser.parse_args()
- np.random.seed(2)
def dataset_split(label):
@@ -34,6 +33,7 @@ def dataset_split(label):
pfhand = open(pos_file, encoding='utf-8')
pos_samples += pfhand.readlines()
pfhand.close()
+ np.random.seed(0)
perm = np.random.permutation(len(pos_samples))
perm_train = perm[0:int(len(pos_samples) * 0.9)]
perm_test = perm[int(len(pos_samples) * 0.9):]

View File

@@ -48,13 +48,12 @@ if __name__ == '__main__':
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
- opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
loss_cb = LossMonitor()
print("============== Starting Testing ==============")
- ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, 1, False)
+ ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
network.set_train(False)
- model = Model(network, loss, opt, metrics={'acc': Accuracy()}, amp_level='O3')
+ model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3')
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc))

View File

@@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""textrcnn export ckpt file to mindir/air"""
import os
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.textrcnn import textrcnn
from src.config import textrcnn_cfg as config
parser = argparse.ArgumentParser(description="textrcnn")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="textrcnn ckpt file.")
parser.add_argument("--file_name", type=str, default="textrcnn", help="textrcnn output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"],
default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
if __name__ == "__main__":
# define net
embedding_table = np.loadtxt(os.path.join(config.preprocess_path, "weight.txt")).astype(np.float32)
net = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=config.cell, batch_size=config.batch_size)
# load checkpoint
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
net.set_train(False)
image = Tensor(np.ones([config.batch_size, 50], np.int32))
export(net, image, file_name=args.file_name, file_format=args.file_format)
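The new script builds the network exactly as train.py does, loads the checkpoint, and traces the graph with a `[config.batch_size, 50]` int32 dummy input; a typical invocation would be `python export.py --ckpt_file <ckpt> --file_format MINDIR` with the flags defined above. As a rough way to sanity-check the exported file afterwards, assuming a MindSpore build that provides `mindspore.load` and `nn.GraphCell` (both the API availability and the file name below are assumptions, not part of this commit):

```python
# Sketch only: load the exported MindIR graph and push one dummy batch through it.
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from src.config import textrcnn_cfg as config

graph = ms.load("textrcnn.mindir")                           # file produced by export.py (placeholder name)
infer_net = nn.GraphCell(graph)
dummy = Tensor(np.ones([config.batch_size, 50], np.int32))   # same shape as the export input above
print(infer_net(dummy).shape)
```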

View File

@@ -100,6 +100,7 @@ bash scripts/run_eval.sh
│ ├──textrcnn.py // textrcnn architecture
│ ├──config.py // parameter configuration
├── train.py // training script
+ ├── export.py // export script
├── eval.py // evaluation script
├── data_helpers.py // dataset split script
├── sample.txt // the shell to train and eval the model without scripts
@@ -129,8 +130,7 @@ Parameters for both training and evaluation can be set in config.py
'emb_path': './word2vec', # the directory to save the embedding file
'embed_size': 300, # the dimension of the word embedding
'save_checkpoint_steps': 149, # per step to save the checkpoint
- 'keep_checkpoint_max': 10, # max checkpoints to save
- 'momentum': 0.9 # the momentum rate
+ 'keep_checkpoint_max': 10 # max checkpoints to save
```
### Performance

View File

@@ -39,5 +39,4 @@ textrcnn_cfg = edict({
'embed_size': 300,
'save_checkpoint_steps': 149,
'keep_checkpoint_max': 10,
- 'momentum': 0.9
})

View File

@@ -76,9 +76,7 @@ def tokenizer(text):
def collect_weight(glove_path, vocab, word_to_idx, embed_size):
""" collect weight """
vocab_size = len(vocab)
- # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
- # binary=False, encoding='utf-8')
- wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
+ wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path,
'GoogleNews-vectors-negative300.bin'),
binary=True)
weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
@@ -164,7 +162,7 @@ def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
writer.commit()
- def create_dataset(base_path, batch_size, num_epochs, is_train):
+ def create_dataset(base_path, batch_size, is_train):
"""Create dataset for training."""
columns_list = ["feature", "label"]
num_consumer = 4
@@ -175,7 +173,7 @@ def create_dataset(base_path, batch_size, num_epochs, is_train):
path = os.path.join(base_path, 'aclImdb_test.mindrecord0')
data_set = ds.MindDataset(path, columns_list, num_consumer)
- ds.config.set_seed(1)
+ ds.config.set_seed(0)
data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
data_set = data_set.batch(batch_size, drop_remainder=True)
return data_set

View File

@@ -47,16 +47,16 @@ class textrcnn(nn.Cell):
self.lstm = P.DynamicRNN(forget_bias=0.0)
self.w1_fw = Parameter(
np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
np.float32), name="w1_fw")
self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float32),
np.float16), name="w1_fw")
self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
name="b1_fw")
self.w1_bw = Parameter(
np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
np.float32), name="w1_bw")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float32),
np.float16), name="w1_bw")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
name="b1_bw")
- self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float32))
- self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float32))
+ self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
+ self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
if cell == "vanilla":
self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
@@ -72,6 +72,12 @@ class textrcnn(nn.Cell):
self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
+ self.rnnWr_fw.to_float(mstype.float16)
+ self.rnnWz_fw.to_float(mstype.float16)
+ self.rnnWh_fw.to_float(mstype.float16)
+ self.rnnWr_bw.to_float(mstype.float16)
+ self.rnnWz_bw.to_float(mstype.float16)
+ self.rnnWh_bw.to_float(mstype.float16)
self.transpose = P.Transpose()
self.reduce_max = P.ReduceMax()
@@ -91,6 +97,9 @@ class textrcnn(nn.Cell):
self.tanh = P.Tanh()
self.sigmoid = P.Sigmoid()
self.slice = P.Slice()
+ self.text_rep_dense.to_float(mstype.float16)
+ self.mydense.to_float(mstype.float16)
+ self.output_dense.to_float(mstype.float16)
def construct(self, x):
"""class construction"""

View File

@@ -22,7 +22,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
+ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
from src.config import textrcnn_cfg as cfg
@@ -31,7 +31,7 @@ from src.dataset import convert_to_mindrecord
from src.textrcnn import textrcnn
from src.utils import get_lr
- set_seed(2)
+ set_seed(0)
if __name__ == '__main__':
@@ -58,7 +58,7 @@ if __name__ == '__main__':
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=cfg.cell, batch_size=cfg.batch_size)
- ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
+ ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, True)
step_size = ds_train.get_dataset_size()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
@@ -70,11 +70,12 @@ if __name__ == '__main__':
opt = nn.Adam(params=network.trainable_params(), learning_rate=lr)
loss_cb = LossMonitor()
+ time_cb = TimeMonitor()
model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")
print("============== Starting Training ==============")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
- model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
+ model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb, time_cb])
print("train success")