!21501 bug fix for protonet

Merge pull request !21501 from wukesong/wl_master_protonet
i-robot 2021-08-09 01:26:03 +00:00 committed by Gitee
commit 7c2fcfe1ee
4 changed files with 32 additions and 25 deletions


@@ -29,7 +29,12 @@ Proto-Net contains 2 parts named Encoder and Relation. The former one has 4 conv
Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [omniglot](https://github.com/brendenlake/omniglot)
The omniglot dataset can be obtained from https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/. It is downloaded automatically when you run the following scripts:
```bash
cd src
python train.py
```
- Dataset size 4.02M, 32462 28*28 images in 1622 classes
- Train 1,200 classes
@@ -39,7 +44,7 @@ Dataset used: [omniglot](https://github.com/brendenlake/omniglot)
- The directory structure is as follows:
```text
```shell
└─Data
├─raw
├─splits
@@ -67,13 +72,13 @@ Dataset used: [omniglot](https://github.com/brendenlake/omniglot)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```shell
# enter script dir, train ProtoNet in standalone
sh run_standalone_train_ascend.sh dataset 1 20 20
# enter script dir, train ProtoNet in distribution
sh run_distribution_ascend.sh dataset rank_table dataset 20
```python
# enter script dir, train ProtoNet
sh run_standalone_train_ascend.sh "../dataset" 1 60 500
# enter script dir, evaluate ProtoNet
sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20
sh run_standalone_eval_ascend.sh "../dataset" "./output/best_ck.ckpt" 1 5
# enter script dir, train ProtoNet distributed
sh run_distribution_ascend.sh "./rank_table.json" "../dataset" 60 500
```
## [Script and Sample Code](#contents)
@@ -120,8 +125,7 @@ Major parameters in train.py and config.py as follows:
### Training
```bash
# enter script dir, train ProtoNet in standalone
sh run_standalone_train_ascend.sh dataset 1 20 20
sh run_standalone_train_ascend.sh "../dataset" 1 60 500
```
The model checkpoint will be saved in the current directory.
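Checkpoint saving itself is not part of this diff. For reference, a minimal sketch of the usual MindSpore callback pattern (the prefix, directory, and step counts here are illustrative assumptions, not values taken from the repository's train.py):

```python
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Hypothetical settings: save every 100 steps, keep at most 5 files.
config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="protonet", directory="./output", config=config)
# Passing callbacks=[ckpt_cb] to model.train(...) then produces files
# such as ./output/protonet-1_100.ckpt in the chosen directory.
```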
@@ -133,11 +137,11 @@ The model checkpoint will be saved in the current directory.
Before running the command below, please check the checkpoint path used for evaluation.
```bash
# enter script dir, evaluate ProtoNet
sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20
sh run_standalone_eval_ascend.sh "../dataset" "./output/best_ck.ckpt" 1 5
```
```text
```shell
Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
```
@@ -149,9 +153,9 @@ Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
| Parameters | ProtoNet |
| -------------------------- | ---------------------------------------------------------- |
| Resource | CentOs 8.2; Ascend 910; CPU 2.60GHz; 192cores; Memory 755G |
| Resource | CentOs 8.2; Ascend 910; CPU 2.60GHz, 192cores; Memory 755G |
| uploaded Date | 03/26/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| MindSpore Version | 1.1.1 |
| Dataset | OMNIGLOT |
| Training Parameters | episode=500, class_num = 5, lr=0.001, classes_per_it_tr=60, num_support_tr=5, num_query_tr=5, classes_per_it_val=20, num_support_val=5, num_query_val=15 |
| Optimizer | Adam |
@@ -161,7 +165,7 @@ Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
| Speed | 215 ms/step |
| Total time | 3 h 23m (8p) |
| Checkpoint for Fine tuning | 440 KB (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ProtoNet |
| Scripts | https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/research/cv/protonet |
# [ModelZoo Homepage](#contents)


@@ -15,14 +15,13 @@
"""
ProtoNet evaluation script.
"""
import os
import numpy as np
from mindspore import dataset as ds
from mindspore import load_checkpoint
import mindspore.context as context
from src.protonet import ProtoNet
from src.parser_util import get_parser
from src.PrototypicalLoss import PrototypicalLoss
import numpy as np
from model_init import init_dataloader
from train import WithLossCell
@@ -67,5 +66,5 @@ if __name__ == '__main__':
options.classes_per_it_val, is_train=False)
Net = WithLossCell(Net, loss_fn)
val_dataloader = init_dataloader(options, 'val', datapath)
load_checkpoint(os.path.join(ckptpath, 'best_ck.ckpt'), net=Net)
load_checkpoint(ckptpath, net=Net)
test(val_dataloader, Net)
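The change above hands the user-supplied checkpoint path straight to `load_checkpoint` instead of appending a hard-coded `best_ck.ckpt` file name, so any checkpoint file can be evaluated. A minimal sketch of the pattern, using an illustrative stand-in network rather than the actual ProtoNet:

```python
import mindspore.nn as nn
from mindspore import load_checkpoint

class TinyNet(nn.Cell):
    """Illustrative placeholder; any nn.Cell works the same way."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(64, 64)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()
# The caller now chooses the exact file; no directory layout or
# file-name convention is baked into the evaluation script.
param_dict = load_checkpoint("./output/best_ck.ckpt", net=net)
```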


@@ -16,7 +16,7 @@
# a simple tutorial follows; more parameters can be set
if [ $# != 4 ]
then
echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS]"
echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS] [EPOCHS]"
exit 1
fi
@@ -33,6 +33,7 @@ RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
export TRAIN_CLASS=$3
export EPOCHS=$4
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
@@ -43,13 +44,16 @@ do
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../train.py ./train_parallel$i
cp ../model_init.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i || exit
env > env.log
python train.py --data_path=$DATA_PATH \
python train.py --dataset_root=$DATA_PATH \
--device_id=$DEVICE_ID --device_target="Ascend" \
--classes_per_it_tr=$TRAIN_CLASS > log 2>&1 &
--classes_per_it_tr=$TRAIN_CLASS \
--experiment_root=./output \
--epochs=$EPOCHS > log 2>&1 &
cd ..
done
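Each pass through the loop gives a rank its own working directory, a copy of the sources, and its own `DEVICE_ID`/`RANK_ID` environment. As an illustration of how such variables are typically consumed on the Python side (an assumption for context; the actual train.py takes the device id via `--device_id`, as in the invocation above):

```python
import os
import mindspore.context as context

# Fall back to device 0 when launched outside the per-rank loop.
device_id = int(os.environ.get("DEVICE_ID", "0"))
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend",
                    device_id=device_id)
```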


@@ -49,7 +49,7 @@ def get_parser():
parser.add_argument('-exp', '--experiment_root',
type=str,
help='root where to store models, losses and accuracies',
default='..' + os.sep + 'output')
default='.' + os.sep + 'output')
parser.add_argument('-nep', '--epochs',
type=int,
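The default `experiment_root` moves from the parent directory to the current directory, which lines up with the `--experiment_root=./output` argument now passed by the distributed launcher. A small self-contained sketch of the resulting behaviour (a hypothetical one-option parser, not the repository's full `get_parser`):

```python
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-exp', '--experiment_root', type=str,
                    default='.' + os.sep + 'output')
args = parser.parse_args([])
# Resolves relative to the current working directory, i.e. the
# per-rank train_parallel<i> folder created by the launch script.
print(os.path.abspath(args.experiment_root))
```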