gru_fix_bug
This commit is contained in:
parent
78c733ffbe
commit
6adb82c2c6
|
@ -1,4 +1,4 @@
|
|||
data:image/s3,"s3://crabby-images/7a9b9/7a9b9296f7e0e122a066854a4dcc81721e8a8e2b" alt=""
|
||||
data:image/s3,"s3://crabby-images/6092b/6092b1ecfe2525404c521c212f13a351fe833785" alt=""
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
|
@ -52,6 +52,26 @@ In this model, we use the Multi30K dataset as our train and test dataset.As trai
|
|||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
## Requirements
|
||||
|
||||
```txt
|
||||
nltk
|
||||
numpy
|
||||
```
|
||||
|
||||
To install nltk, you should install nltk as follow:
|
||||
|
||||
```bash
|
||||
pip install nltk
|
||||
```
|
||||
|
||||
Then you should download extra packages as follow:
|
||||
|
||||
```python
|
||||
import nltk
|
||||
nltk.download()
|
||||
```
|
||||
|
||||
# [Quick Start](#content)
|
||||
|
||||
After dataset preparation, you can start training and evaluation as follows:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformer evaluation script."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -41,8 +41,13 @@ def run_gru_eval():
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
|
||||
device_id=args.device_id, save_graphs=False)
|
||||
prefix = "multi30k_test_mindrecord_32"
|
||||
mindrecord_file = os.path.join(args.dataset_path, prefix)
|
||||
if not os.path.exists(mindrecord_file):
|
||||
print("dataset file {} not exists, please check!".format(mindrecord_file))
|
||||
raise ValueError(mindrecord_file)
|
||||
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
|
||||
dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
||||
dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("dataset size is {}".format(dataset_size))
|
||||
network = Seq2Seq(config, is_training=False)
|
||||
|
|
|
@ -40,9 +40,9 @@ fi
|
|||
DATASET_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
@ -41,9 +41,9 @@ fi
|
|||
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
rm -rf ./eval
|
||||
|
|
|
@ -33,9 +33,9 @@ get_real_path(){
|
|||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
@ -99,8 +99,13 @@ if __name__ == '__main__':
|
|||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
prefix = "multi30k_train_mindrecord_32_"
|
||||
mindrecord_file = os.path.join(args.dataset_path, prefix+"0")
|
||||
if not os.path.exists(mindrecord_file):
|
||||
print("dataset file {} not exists, please check!".format(mindrecord_file))
|
||||
raise ValueError(mindrecord_file)
|
||||
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size,
|
||||
dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank)
|
||||
dataset_path=mindrecord_file, rank_size=device_num, rank_id=rank)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("dataset size is {}".format(dataset_size))
|
||||
network = Seq2Seq(config)
|
||||
|
|
Loading…
Reference in New Issue