From 1fae83d74697bf29250d3e3e0c78bedd8551a347 Mon Sep 17 00:00:00 2001
From: caojian05
Date: Mon, 25 May 2020 22:16:43 +0800
Subject: [PATCH] add train and eval script for LSTM

---
 example/lstm_aclImdb/README.md  | 100 +++++++++++++++++++++
 example/lstm_aclImdb/config.py  |  33 +++++++
 example/lstm_aclImdb/dataset.py |  92 +++++++++++++++++++
 example/lstm_aclImdb/eval.py    |  81 +++++++++++++++++
 example/lstm_aclImdb/imdb.py    | 155 ++++++++++++++++++++++++++++++++
 example/lstm_aclImdb/train.py   |  83 +++++++++++++++++
 mindspore/model_zoo/lstm.py     | 115 ++++++++++++++++++++++++
 7 files changed, 659 insertions(+)
 create mode 100644 example/lstm_aclImdb/README.md
 create mode 100644 example/lstm_aclImdb/config.py
 create mode 100644 example/lstm_aclImdb/dataset.py
 create mode 100644 example/lstm_aclImdb/eval.py
 create mode 100644 example/lstm_aclImdb/imdb.py
 create mode 100644 example/lstm_aclImdb/train.py
 create mode 100644 mindspore/model_zoo/lstm.py

diff --git a/example/lstm_aclImdb/README.md b/example/lstm_aclImdb/README.md
new file mode 100644
index 0000000000..95ac30f3dc
--- /dev/null
+++ b/example/lstm_aclImdb/README.md
@@ -0,0 +1,100 @@
+# LSTM Example
+
+## Description
+
+This example is for LSTM model training and evaluation.
+
+## Requirements
+
+- Install [MindSpore](https://www.mindspore.cn/install/en).
+
+- Download the aclImdb_v1 dataset.
+
+> Unzip the aclImdb_v1 dataset to any path you want; the folder structure should be as follows:
+> ```
+> .
+> ├── train  # train dataset
+> └── test   # infer dataset
+> ```
+
+- Download the GloVe file.
+
+> Unzip glove.6B.zip to any path you want; the folder structure should be as follows:
+> ```
+> .
+> ├── glove.6B.100d.txt
+> ├── glove.6B.200d.txt
+> ├── glove.6B.300d.txt  # we will use this one later.
+> └── glove.6B.50d.txt
+> ```
+
+> Add a new line at the beginning of the file named `glove.6B.300d.txt`.
+> It states that a total of 400,000 words will be read, each represented by a 300-dimensional word vector.
+> ```
+> 400000 300
+> ```
+
+## Running the Example
+
+### Training
+
+```
+python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path > out.train.log 2>&1 &
+```
+The python command above will run in the background; you can view the results through the file `out.train.log`.
+
+After training, you'll get some checkpoint files under the script folder by default.
+
+You will get the loss value as follows:
+```
+# grep "loss is " out.train.log
+epoch: 1 step: 390, loss is 0.6003723
+epoch: 2 step: 390, loss is 0.35312173
+...
+```
+
+### Evaluation
+
+```
+python eval.py --ckpt_path=./lstm-20-390.ckpt > out.eval.log 2>&1 &
+```
+The above python command will run in the background; you can view the results through the file `out.eval.log`.
+
+You will get the accuracy as follows:
+```
+# grep "acc" out.eval.log
+result: {'acc': 0.83}
+```
+
+## Usage
+
+### Training
+```
+usage: train.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
+                [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
+                [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]
+
+parameters/options:
+  --preprocess          whether to preprocess data.
+  --aclimdb_path        path where the dataset is stored.
+  --glove_path          path where the GloVe file is stored.
+  --preprocess_path     path where the preprocessed data is stored.
+  --ckpt_path           the path to save the checkpoint file.
+  --device_target       the target device to run, supports "GPU" and "CPU".
+```
+
+### Evaluation
+
+```
+usage: eval.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
+               [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
+               [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]
+
+parameters/options:
+  --preprocess          whether to preprocess data.
+  --aclimdb_path        path where the dataset is stored.
+  --glove_path          path where the GloVe file is stored.
+  --preprocess_path     path where the preprocessed data is stored.
+  --ckpt_path           the checkpoint file path used to evaluate the model.
+  --device_target       the target device to run, supports "GPU" and "CPU".
+```
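
Reviewer note: if you prefer to prepend the `400000 300` header to the GloVe file programmatically instead of editing it by hand, a minimal sketch (the path is an assumption, point it at your own GloVe directory; this reads the whole ~1 GB file into memory, which is fine for a one-off):

```python
# Prepend the "400000 300" header so gensim can read glove.6B.300d.txt
# in word2vec text format: 400,000 vocabulary entries, 300 dims each.
glove_file = "./glove/glove.6B.300d.txt"  # hypothetical location

with open(glove_file, "r", encoding="utf-8") as f:
    content = f.read()

with open(glove_file, "w", encoding="utf-8") as f:
    f.write("400000 300\n" + content)
```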
diff --git a/example/lstm_aclImdb/config.py b/example/lstm_aclImdb/config.py
new file mode 100644
index 0000000000..688760111c
--- /dev/null
+++ b/example/lstm_aclImdb/config.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config setting
+"""
+from easydict import EasyDict as edict

+# LSTM CONFIG
+lstm_cfg = edict({
+    'num_classes': 2,
+    'learning_rate': 0.1,
+    'momentum': 0.9,
+    'num_epochs': 20,
+    'batch_size': 64,
+    'embed_size': 300,
+    'num_hiddens': 100,
+    'num_layers': 2,
+    'bidirectional': True,
+    'save_checkpoint_steps': 390,
+    'keep_checkpoint_max': 10
+})
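
Reviewer note: the `edict` wrapper simply exposes the dictionary keys as attributes, which is why the scripts below can read settings with dot notation. A quick illustration:

```python
from easydict import EasyDict as edict

cfg = edict({'batch_size': 64, 'num_epochs': 20})

# Keys are reachable both ways; train.py and eval.py use attribute access.
assert cfg.batch_size == cfg['batch_size'] == 64
print(cfg.num_epochs)  # 20
```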
diff --git a/example/lstm_aclImdb/dataset.py b/example/lstm_aclImdb/dataset.py
new file mode 100644
index 0000000000..24797198e0
--- /dev/null
+++ b/example/lstm_aclImdb/dataset.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Data operations, will be used in train.py and eval.py.
+"""
+import os
+
+import numpy as np
+
+from imdb import ImdbParser
+import mindspore.dataset as ds
+from mindspore.mindrecord import FileWriter
+
+
+def create_dataset(data_home, batch_size, repeat_num=1, training=True):
+    """Data operations."""
+    ds.config.set_seed(1)
+    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
+    if not training:
+        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord0")
+
+    data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4)
+
+    # apply shuffle, batch and repeat operations
+    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
+    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
+    data_set = data_set.repeat(count=repeat_num)
+
+    return data_set
+
+
+def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
+    """
+    convert imdb dataset to mindrecord dataset
+    """
+    if weight_np is not None:
+        np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)
+
+    # write mindrecord
+    schema_json = {"id": {"type": "int32"},
+                   "label": {"type": "int32"},
+                   "feature": {"type": "int32", "shape": [-1]}}
+
+    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord")
+    if not training:
+        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord")
+
+    def get_imdb_data(features, labels):
+        data_list = []
+        for i, (label, feature) in enumerate(zip(labels, features)):
+            data_json = {"id": i,
+                         "label": int(label),
+                         "feature": feature.reshape(-1)}
+            data_list.append(data_json)
+        return data_list
+
+    writer = FileWriter(data_dir, shard_num=4)
+    data = get_imdb_data(features, labels)
+    writer.add_schema(schema_json, "nlp_schema")
+    writer.add_index(["id", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
+
+def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
+    """
+    convert imdb dataset to mindrecord dataset
+    """
+    parser = ImdbParser(aclimdb_path, glove_path, embed_size)
+    parser.parse()
+
+    if not os.path.exists(preprocess_path):
+        print(f"preprocess path {preprocess_path} does not exist, creating it")
+        os.makedirs(preprocess_path)
+
+    train_features, train_labels, train_weight_np = parser.get_datas('train')
+    _convert_to_mindrecord(preprocess_path, train_features, train_labels, train_weight_np)
+
+    test_features, test_labels, _ = parser.get_datas('test')
+    _convert_to_mindrecord(preprocess_path, test_features, test_labels, training=False)
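
Reviewer note: a minimal sketch of how the two entry points above fit together (this mirrors what train.py does internally; the relative paths are assumptions, match them to your setup):

```python
from config import lstm_cfg as cfg
from dataset import convert_to_mindrecord, create_dataset

# One-time: parse aclImdb + GloVe, write MindRecord shards and weight.txt.
# Requires gensim and the GloVe file (with the "400000 300" header) on disk.
convert_to_mindrecord(cfg.embed_size, "./aclImdb", "./preprocess", "./glove")

# Afterwards, build an iterable dataset from the shards.
ds_train = create_dataset("./preprocess", cfg.batch_size, repeat_num=1, training=True)
print("steps per epoch:", ds_train.get_dataset_size())  # 25000 // 64 = 390
```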
diff --git a/example/lstm_aclImdb/eval.py b/example/lstm_aclImdb/eval.py
new file mode 100644
index 0000000000..e76d40ac67
--- /dev/null
+++ b/example/lstm_aclImdb/eval.py
@@ -0,0 +1,81 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+#################eval lstm example on aclImdb########################
+python eval.py --ckpt_path=./lstm-20-390.ckpt
+"""
+import argparse
+import os
+
+import numpy as np
+
+from config import lstm_cfg as cfg
+from dataset import create_dataset, convert_to_mindrecord
+from mindspore import Tensor, nn, Model, context
+from mindspore.model_zoo.lstm import SentimentNet
+from mindspore.nn import Accuracy
+from mindspore.train.callback import LossMonitor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
+    parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
+                        help='whether to preprocess data.')
+    parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
+                        help='path where the dataset is stored.')
+    parser.add_argument('--glove_path', type=str, default="./glove",
+                        help='path where the GloVe is stored.')
+    parser.add_argument('--preprocess_path', type=str, default="./preprocess",
+                        help='path where the preprocessed data is stored.')
+    parser.add_argument('--ckpt_path', type=str, default=None,
+                        help='the checkpoint file path used to evaluate the model.')
+    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
+                        help='the target device to run, supports "GPU" and "CPU". Default: "GPU".')
+    args = parser.parse_args()
+
+    context.set_context(
+        mode=context.GRAPH_MODE,
+        save_graphs=False,
+        device_target=args.device_target)
+
+    if args.preprocess == "true":
+        print("============== Starting Data Pre-processing ==============")
+        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)
+
+    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
+    network = SentimentNet(vocab_size=embedding_table.shape[0],
+                           embed_size=cfg.embed_size,
+                           num_hiddens=cfg.num_hiddens,
+                           num_layers=cfg.num_layers,
+                           bidirectional=cfg.bidirectional,
+                           num_classes=cfg.num_classes,
+                           weight=Tensor(embedding_table),
+                           batch_size=cfg.batch_size)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
+    loss_cb = LossMonitor()
+
+    model = Model(network, loss, opt, {'acc': Accuracy()})
+
+    print("============== Starting Testing ==============")
+    ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, training=False)
+    param_dict = load_checkpoint(args.ckpt_path)
+    load_param_into_net(network, param_dict)
+    if args.device_target == "CPU":
+        acc = model.eval(ds_eval, dataset_sink_mode=False)
+    else:
+        acc = model.eval(ds_eval)
+    print("============== Accuracy:{} ==============".format(acc))
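
Reviewer note: the `{'acc': ...}` printed at the end is just the fraction of argmax predictions that match the labels. A small numpy illustration of what the `Accuracy` metric measures (illustrative only, not MindSpore code):

```python
import numpy as np

# logits for a batch of 4 samples, 2 classes (neg/pos)
logits = np.array([[0.2, 0.8],
                   [1.5, 0.3],
                   [0.1, 0.9],
                   [2.0, 1.0]])
labels = np.array([1, 0, 0, 0])

predictions = logits.argmax(axis=1)   # [1, 0, 1, 0]
acc = (predictions == labels).mean()  # 3 of 4 correct
print({'acc': acc})                   # {'acc': 0.75}
```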
diff --git a/example/lstm_aclImdb/imdb.py b/example/lstm_aclImdb/imdb.py
new file mode 100644
index 0000000000..66d04f1281
--- /dev/null
+++ b/example/lstm_aclImdb/imdb.py
@@ -0,0 +1,155 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+imdb dataset parser.
+"""
+import os
+from itertools import chain
+
+import gensim
+import numpy as np
+
+
+class ImdbParser:
+    """
+    parse aclImdb data to features and labels.
+    sentence -> tokenized -> encoded -> padded -> features
+    """
+
+    def __init__(self, imdb_path, glove_path, embed_size=300):
+        self.__segs = ['train', 'test']
+        self.__label_dic = {'pos': 1, 'neg': 0}
+        self.__imdb_path = imdb_path
+        self.__glove_dim = embed_size
+        self.__glove_file = os.path.join(glove_path, 'glove.6B.' + str(self.__glove_dim) + 'd.txt')
+
+        # properties
+        self.__imdb_datas = {}
+        self.__features = {}
+        self.__labels = {}
+        self.__vocab = {}
+        self.__word2idx = {}
+        self.__weight_np = {}
+        self.__wvmodel = None
+
+    def parse(self):
+        """
+        parse imdb data to memory
+        """
+        self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)
+
+        for seg in self.__segs:
+            self.__parse_imdb_datas(seg)
+            self.__parse_features_and_labels(seg)
+            self.__gen_weight_np(seg)
+
+    def __parse_imdb_datas(self, seg):
+        """
+        load data from txt
+        """
+        data_lists = []
+        for label_name, label_id in self.__label_dic.items():
+            sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
+            for file in os.listdir(sentence_dir):
+                with open(os.path.join(sentence_dir, file), mode='r', encoding='utf8') as f:
+                    sentence = f.read().replace('\n', '')
+                    data_lists.append([sentence, label_id])
+        self.__imdb_datas[seg] = data_lists
+
+    def __parse_features_and_labels(self, seg):
+        """
+        parse features and labels
+        """
+        features = []
+        labels = []
+        for sentence, label in self.__imdb_datas[seg]:
+            features.append(sentence)
+            labels.append(label)
+
+        self.__features[seg] = features
+        self.__labels[seg] = labels
+
+        # update features to tokenized
+        self.__update_features_to_tokenized(seg)
+        # parse vocab
+        self.__parse_vocab(seg)
+        # encode features
+        self.__encode_features(seg)
+        # pad features
+        self.__padding_features(seg)
+
+    def __update_features_to_tokenized(self, seg):
+        tokenized_features = []
+        for sentence in self.__features[seg]:
+            tokenized_sentence = [word.lower() for word in sentence.split(" ")]
+            tokenized_features.append(tokenized_sentence)
+        self.__features[seg] = tokenized_features
+
+    def __parse_vocab(self, seg):
+        # vocab
+        tokenized_features = self.__features[seg]
+        vocab = set(chain(*tokenized_features))
+        self.__vocab[seg] = vocab
+
+        # word_to_idx: {'hello': 1, 'world': 111, ..., '<unk>': 0}
+        word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
+        word_to_idx['<unk>'] = 0
+        self.__word2idx[seg] = word_to_idx
+
+    def __encode_features(self, seg):
+        """ encode word to index """
+        word_to_idx = self.__word2idx['train']
+        encoded_features = []
+        for tokenized_sentence in self.__features[seg]:
+            encoded_sentence = []
+            for word in tokenized_sentence:
+                encoded_sentence.append(word_to_idx.get(word, 0))
+            encoded_features.append(encoded_sentence)
+        self.__features[seg] = encoded_features
+
+    def __padding_features(self, seg, maxlen=500, pad=0):
+        """ pad all features to the same length """
+        padded_features = []
+        for feature in self.__features[seg]:
+            if len(feature) >= maxlen:
+                padded_feature = feature[:maxlen]
+            else:
+                padded_feature = feature
+                while len(padded_feature) < maxlen:
+                    padded_feature.append(pad)
+            padded_features.append(padded_feature)
+        self.__features[seg] = padded_features
+
+    def __gen_weight_np(self, seg):
+        """
+        generate weight by gensim
+        """
+        weight_np = np.zeros((len(self.__word2idx[seg]), self.__glove_dim), dtype=np.float32)
+        for word, idx in self.__word2idx[seg].items():
+            if word not in self.__wvmodel:
+                continue
+            word_vector = self.__wvmodel.get_vector(word)
+            weight_np[idx, :] = word_vector
+
+        self.__weight_np[seg] = weight_np
+
+    def get_datas(self, seg):
+        """
+        return features, labels, and weight
+        """
+        features = np.array(self.__features[seg]).astype(np.int32)
+        labels = np.array(self.__labels[seg]).astype(np.int32)
+        weight = np.array(self.__weight_np[seg])
+        return features, labels, weight
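
Reviewer note: to see what the parser produces, here is a hand-rolled sketch of the tokenize/encode/pad pipeline on a toy corpus (the words and indices are hypothetical; the real vocabulary is built from the train split and index 0 is reserved for unknown words):

```python
from itertools import chain

sentences = ["This movie is great", "Terrible movie"]

# tokenize (lowercase, whitespace split, as in the parser)
tokenized = [[w.lower() for w in s.split(" ")] for s in sentences]

# vocab: index 0 is reserved for '<unk>', known words start at 1
vocab = sorted(set(chain(*tokenized)))  # sorted only to make the demo deterministic
word_to_idx = {w: i + 1 for i, w in enumerate(vocab)}
word_to_idx['<unk>'] = 0

# encode, mapping out-of-vocabulary words to 0
encoded = [[word_to_idx.get(w, 0) for w in s] for s in tokenized]

# pad / truncate to a fixed length (the parser uses maxlen=500)
maxlen = 6
padded = [(s + [0] * maxlen)[:maxlen] for s in encoded]
print(padded)  # [[5, 3, 2, 1, 0, 0], [4, 3, 0, 0, 0, 0]]
```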
diff --git a/example/lstm_aclImdb/train.py b/example/lstm_aclImdb/train.py
new file mode 100644
index 0000000000..3d1c670f4e
--- /dev/null
+++ b/example/lstm_aclImdb/train.py
@@ -0,0 +1,83 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+#################train lstm example on aclImdb########################
+python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path
+"""
+import argparse
+import os
+
+import numpy as np
+
+from config import lstm_cfg as cfg
+from dataset import convert_to_mindrecord
+from dataset import create_dataset
+from mindspore import Tensor, nn, Model, context
+from mindspore.model_zoo.lstm import SentimentNet
+from mindspore.nn import Accuracy
+from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
+    parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
+                        help='whether to preprocess data.')
+    parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
+                        help='path where the dataset is stored.')
+    parser.add_argument('--glove_path', type=str, default="./glove",
+                        help='path where the GloVe is stored.')
+    parser.add_argument('--preprocess_path', type=str, default="./preprocess",
+                        help='path where the preprocessed data is stored.')
+    parser.add_argument('--ckpt_path', type=str, default="./",
+                        help='the path to save the checkpoint file.')
+    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
+                        help='the target device to run, supports "GPU" and "CPU". Default: "GPU".')
+    args = parser.parse_args()
+
+    context.set_context(
+        mode=context.GRAPH_MODE,
+        save_graphs=False,
+        device_target=args.device_target)
+
+    if args.preprocess == "true":
+        print("============== Starting Data Pre-processing ==============")
+        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)
+
+    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
+    network = SentimentNet(vocab_size=embedding_table.shape[0],
+                           embed_size=cfg.embed_size,
+                           num_hiddens=cfg.num_hiddens,
+                           num_layers=cfg.num_layers,
+                           bidirectional=cfg.bidirectional,
+                           num_classes=cfg.num_classes,
+                           weight=Tensor(embedding_table),
+                           batch_size=cfg.batch_size)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
+    loss_cb = LossMonitor()
+
+    model = Model(network, loss, opt, {'acc': Accuracy()})
+
+    print("============== Starting Training ==============")
+    ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
+    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
+    if args.device_target == "CPU":
+        model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
+    else:
+        model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
+    print("============== Training Success ==============")
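
Reviewer note: why `save_checkpoint_steps: 390`? The aclImdb train split has 25,000 reviews, and `create_dataset` batches with `drop_remainder=True`, so one epoch is 390 steps and a checkpoint lands at each epoch boundary; `lstm-20-390.ckpt` is the file written after the final epoch. As a quick check:

```python
train_reviews = 25000                           # aclImdb train split size
batch_size = 64                                 # cfg.batch_size
steps_per_epoch = train_reviews // batch_size   # partial batch is dropped
print(steps_per_epoch)                          # 390
```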
diff --git a/mindspore/model_zoo/lstm.py b/mindspore/model_zoo/lstm.py
new file mode 100644
index 0000000000..35fe674303
--- /dev/null
+++ b/mindspore/model_zoo/lstm.py
@@ -0,0 +1,115 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LSTM."""
+import math
+
+import numpy as np
+
+from mindspore import Parameter, Tensor, nn
+from mindspore.common.initializer import initializer
+from mindspore.ops import operations as P
+
+
+def init_lstm_weight(
+        input_size,
+        hidden_size,
+        num_layers,
+        bidirectional,
+        has_bias=True):
+    """Initialize lstm weight."""
+    num_directions = 1
+    if bidirectional:
+        num_directions = 2
+
+    weight_size = 0
+    gate_size = 4 * hidden_size
+    for layer in range(num_layers):
+        for _ in range(num_directions):
+            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
+            weight_size += gate_size * input_layer_size
+            weight_size += gate_size * hidden_size
+            if has_bias:
+                weight_size += 2 * gate_size
+
+    stdv = 1 / math.sqrt(hidden_size)
+    w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+    w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
+
+    return w
+
+
+# Initialize short-term memory (h) and long-term memory (c) to 0
+def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
+    """init default input."""
+    num_directions = 1
+    if bidirectional:
+        num_directions = 2
+
+    h = Tensor(
+        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
+    c = Tensor(
+        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
+    return h, c
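+
+
+# A worked example of the flat weight layout above (explanatory only):
+# with embed_size=300, num_hiddens=100, num_layers=2, bidirectional=True,
+#   gate_size = 4 * 100 = 400                       (i, f, c, o gates)
+#   layer 0, per direction: 400*300 + 400*100 + 2*400 = 160800
+#   layer 1, per direction: 400*(100*2) + 400*100 + 2*400 = 120800
+#   weight_size = 2 directions * (160800 + 120800) = 563200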
+
+
+class SentimentNet(nn.Cell):
+    """Sentiment network structure."""
+
+    def __init__(self,
+                 vocab_size,
+                 embed_size,
+                 num_hiddens,
+                 num_layers,
+                 bidirectional,
+                 num_classes,
+                 weight,
+                 batch_size):
+        super(SentimentNet, self).__init__()
+        # Map words to vectors
+        self.embedding = nn.Embedding(vocab_size,
+                                      embed_size,
+                                      embedding_table=weight)
+        self.embedding.embedding_table.requires_grad = False
+        self.trans = P.Transpose()
+        self.perm = (1, 0, 2)
+        self.encoder = nn.LSTM(input_size=embed_size,
+                               hidden_size=num_hiddens,
+                               num_layers=num_layers,
+                               has_bias=True,
+                               bidirectional=bidirectional,
+                               dropout=0.0)
+        w_init = init_lstm_weight(
+            embed_size,
+            num_hiddens,
+            num_layers,
+            bidirectional)
+        self.encoder.weight = w_init
+        self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
+
+        self.concat = P.Concat(1)
+        if bidirectional:
+            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
+        else:
+            self.decoder = nn.Dense(num_hiddens * 2, num_classes)
+
+    def construct(self, inputs):
+        # inputs: (64, 500); after embedding: (64, 500, 300)
+        embeddings = self.embedding(inputs)
+        embeddings = self.trans(embeddings, self.perm)
+        output, _ = self.encoder(embeddings, (self.h, self.c))
+        # output[i]: (64, 200) -> encoding: (64, 400)
+        encoding = self.concat((output[0], output[1]))
+        outputs = self.decoder(encoding)
+        return outputs
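
Reviewer note: a shape walkthrough of `construct` under the example config (batch_size=64, seq_len=500, embed_size=300, num_hiddens=100, bidirectional=True), sketched in plain numpy to make the comments above concrete. This mirrors the shapes only, not the LSTM math:

```python
import numpy as np

batch, seq_len, embed, hidden = 64, 500, 300, 100

inputs = np.zeros((batch, seq_len), dtype=np.int32)  # token indices
embeddings = np.zeros((batch, seq_len, embed))       # after nn.Embedding
embeddings = embeddings.transpose(1, 0, 2)           # perm (1, 0, 2) -> (500, 64, 300)

# bidirectional LSTM output: (seq_len, batch, 2 * hidden)
output = np.zeros((seq_len, batch, 2 * hidden))      # (500, 64, 200)

# concat of two time steps along axis 1 -> (64, 400), matching Dense(num_hiddens * 4)
encoding = np.concatenate((output[0], output[1]), axis=1)
print(encoding.shape)                                # (64, 400)
```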