!21501 bug fix for protonet
Merge pull request !21501 from wukesong/wl_master_protonet
commit 7c2fcfe1ee
@@ -29,7 +29,12 @@ Proto-Net contains 2 parts named Encoder and Relation. The former one has 4 conv
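Before the diff itself, a note on the architecture sentence above: the encoder is the standard four-block convolutional stack from the ProtoNet paper. A minimal MindSpore sketch, assuming 64 channels per block and a final flatten (the real definition lives in src/protonet.py):

```python
# Minimal sketch of the 4-block conv encoder (sizes assumed; see src/protonet.py).
import mindspore.nn as nn

def conv_block(in_c, out_c):
    # conv -> batchnorm -> relu -> 2x2 maxpool, repeated four times in the encoder
    return nn.SequentialCell([
        nn.Conv2d(in_c, out_c, 3, pad_mode='pad', padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ])

class Encoder(nn.Cell):
    def __init__(self, in_channels=1, hidden=64):
        super().__init__()
        self.blocks = nn.SequentialCell([
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        ])

    def construct(self, x):
        x = self.blocks(x)             # (N, 64, 1, 1) for 28*28 omniglot input
        return x.view(x.shape[0], -1)  # flatten to (N, 64) embeddings
```

Each 28*28 omniglot glyph thus maps to a 64-dimensional embedding, and class prototypes are the per-class means of these embeddings.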
 Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.

 Dataset used: [omniglot](https://github.com/brendenlake/omniglot)

+The omniglot dataset can be obtained from (https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/); it is downloaded automatically when you run the scripts:
+
+```bash
+cd src
+python train.py
+```
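Once that run finishes, you can sanity-check that the download produced the folders named in the directory tree shown below; a minimal sketch, with the `../dataset` root taken from the launch examples on this page:

```python
# Minimal sketch: verify the omniglot download landed where the loader expects.
# Folder names follow the directory tree in this README; the root is assumed.
import os

data_root = "../dataset"        # adjust to your [DATA_PATH]
for sub in ("raw", "spilts"):   # 'spilts' is the spelling used in the tree below
    path = os.path.join(data_root, sub)
    print(path, "exists" if os.path.isdir(path) else "MISSING")
```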

 - Dataset size 4.02M, 32462 28*28 images in 1622 classes
 - Train 1,200 classes
@@ -39,7 +44,7 @@ Dataset used: [omniglot](https://github.com/brendenlake/omniglot)
 - The directory structure is as follows:

-```text
+```shell
 └─Data
   ├─raw
   ├─spilts
@@ -67,13 +72,13 @@ Dataset used: [omniglot](https://github.com/brendenlake/omniglot)
 After installing MindSpore via the official website, you can start training and evaluation as follows:

-```shell
-# enter script dir, train ProtoNet in standalone
-sh run_standalone_train_ascend.sh dataset 1 20 20
-# enter script dir, train ProtoNet in distribution
-sh run_distribution_ascend.sh dataset rank_table dataset 20
+```python
+# enter script dir, train ProtoNet
+sh run_standalone_train_ascend.sh "../dataset" 1 60 500
 # enter script dir, evaluate ProtoNet
-sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20
+sh run_standalone_eval_ascend.sh "../dataset" "./output/best_ck.ckpt" 1 5
+# enter script dir, train ProtoNet distributed
+sh run_distribution_ascend.sh "./rank_table.json" "../dataset" 60 500
 ```
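For readers new to these scripts, here is a hedged sketch of what the standalone launch reduces to. The flag names come from the script and parser diffs further down this page; exactly how run_standalone_train_ascend.sh forwards its positional arguments is an assumption:

```python
# Hedged sketch: the rough equivalent of
#   sh run_standalone_train_ascend.sh "../dataset" 1 60 500
# using flag names that appear in this PR's diffs.
import subprocess

subprocess.run(
    [
        "python", "train.py",
        "--dataset_root", "../dataset",   # [DATA_PATH]: omniglot root
        "--device_id", "1",               # [DEVICE_ID]: Ascend device to use
        "--classes_per_it_tr", "60",      # [TRAIN_CLASS]: classes per training episode
        "--epochs", "500",                # [EPOCHS]: number of training epochs
        "--device_target", "Ascend",
        "--experiment_root", "./output",  # checkpoints are written here
    ],
    check=True,
)
```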
## [Script and Sample Code](#contents)
@@ -120,8 +125,7 @@ Major parameters in train.py and config.py as follows:
 ### Training

 ```bash
 # enter script dir, train ProtoNet in standalone
-sh run_standalone_train_ascend.sh dataset 1 20 20
+sh run_standalone_train_ascend.sh "../dataset" 1 60 500
 ```

 The model checkpoint will be saved in the current directory.
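To reload that checkpoint outside the provided scripts, a minimal sketch; the ./output/best_ck.ckpt path matches the evaluation examples on this page, while the bare ProtoNet() constructor is an assumption:

```python
# Minimal sketch: reload a trained checkpoint for evaluation or fine-tuning.
from mindspore import load_checkpoint
from src.protonet import ProtoNet

net = ProtoNet()                                   # default constructor assumed
load_checkpoint("./output/best_ck.ckpt", net=net)  # loads weights in place
```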
@@ -133,11 +137,11 @@ The model checkpoint will be saved in the current directory.
 Before running the command below, please check the checkpoint path used for evaluation.

 ```bash
 # enter script dir, evaluate ProtoNet
-sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20
+sh run_standalone_eval_ascend.sh "../dataset" "./output/best_ck.ckpt" 1 5
 ```

-```text
+```shell
 Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
 ```
@@ -149,9 +153,9 @@ Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
 | Parameters                 | ProtoNet                                                   |
 | -------------------------- | ---------------------------------------------------------- |
-| Resource                   | CentOs 8.2; Ascend 910; CPU 2.60GHz; 192cores; Memory 755G |
+| Resource                   | CentOs 8.2; Ascend 910; CPU 2.60GHz, 192cores; Memory 755G |
 | uploaded Date              | 03/26/2021 (month/day/year)                                |
-| MindSpore Version          | 1.2.0                                                      |
+| MindSpore Version          | 1.1.1                                                      |
 | Dataset                    | OMNIGLOT                                                   |
 | Training Parameters        | episode=500, class_num=5, lr=0.001, classes_per_it_tr=60, num_support_tr=5, num_query_tr=5, classes_per_it_val=20, num_support_val=5, num_query_val=15 |
 | Optimizer                  | Adam                                                       |
@@ -161,7 +165,7 @@ Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
 | Speed                      | 215 ms/step                                                |
 | Total time                 | 3 h 23m (8p)                                               |
 | Checkpoint for Fine tuning | 440 KB (.ckpt file)                                        |
-| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ProtoNet |
+| Scripts                    | https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/research/cv/protonet |
# [ModelZoo Homepage](#contents)
@@ -15,14 +15,13 @@
 """
 ProtoNet evaluation script.
 """
 import os
 import numpy as np
 from mindspore import dataset as ds
 from mindspore import load_checkpoint
 import mindspore.context as context
 from src.protonet import ProtoNet
 from src.parser_util import get_parser
 from src.PrototypicalLoss import PrototypicalLoss
-import numpy as np
 from model_init import init_dataloader
 from train import WithLossCell
@@ -67,5 +66,5 @@ if __name__ == '__main__':
                               options.classes_per_it_val, is_train=False)
 Net = WithLossCell(Net, loss_fn)
 val_dataloader = init_dataloader(options, 'val', datapath)
-load_checkpoint(os.path.join(ckptpath, 'best_ck.ckpt'), net=Net)
+load_checkpoint(ckptpath, net=Net)
 test(val_dataloader, Net)
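The net effect of this hunk is that ckptpath is now treated as the full checkpoint file path rather than a directory holding best_ck.ckpt. Pieced together, the evaluation flow looks roughly like this; the constructor signatures are assumptions based on the names in the diff, not verified APIs:

```python
# Hedged sketch of eval.py's flow after the fix (signatures assumed).
loss_fn = PrototypicalLoss(options.num_support_val, options.num_query_val,
                           options.classes_per_it_val, is_train=False)
net = WithLossCell(ProtoNet(), loss_fn)  # wrap the network with its loss
load_checkpoint(ckptpath, net=net)       # ckptpath is the full .ckpt path now
test(init_dataloader(options, 'val', datapath), net)
```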
@@ -16,7 +16,7 @@
 # a simple tutorial follows; more parameters can be set
 if [ $# != 4 ]
 then
-    echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS]"
+    echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS] [EPOCHS]"
     exit 1
 fi
@@ -33,6 +33,7 @@ RANK_TABLE_FILE=$(realpath $1)
 export RANK_TABLE_FILE
 export DATA_PATH=$2
 export TRAIN_CLASS=$3
+export EPOCHS=$4
 echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

 export SERVER_ID=0
@@ -43,13 +44,16 @@ do
     export RANK_ID=$((rank_start + i))
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
-    cp -r ./src ./train_parallel$i
-    cp ./train.py ./train_parallel$i
+    cp -r ../src ./train_parallel$i
+    cp ../train.py ./train_parallel$i
+    cp ../model_init.py ./train_parallel$i
     echo "start training for rank $RANK_ID, device $DEVICE_ID"
     cd ./train_parallel$i || exit
     env > env.log
-    python train.py --data_path=$DATA_PATH \
+    python train.py --dataset_root=$DATA_PATH \
         --device_id=$DEVICE_ID --device_target="Ascend" \
-        --classes_per_it_tr=$TRAIN_CLASS > log 2>&1 &
+        --classes_per_it_tr=$TRAIN_CLASS \
+        --experiment_root=./output \
+        --epochs=$EPOCHS > log 2>&1 &
     cd ..
 done
@@ -49,7 +49,7 @@ def get_parser():
     parser.add_argument('-exp', '--experiment_root',
                         type=str,
                         help='root where to store models, losses and accuracies',
-                        default='..' + os.sep + 'output')
+                        default='.' + os.sep + 'output')

     parser.add_argument('-nep', '--epochs',
                         type=int,
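For completeness, a hedged sketch of consuming these flags; get_parser and the flag names come from this diff, while the parse_args call is an assumption about how train.py and eval.py use it:

```python
# Hedged usage sketch for the parser defined in src/parser_util.py.
from src.parser_util import get_parser

options = get_parser().parse_args()
print(options.experiment_root)  # defaults to './output' after this fix
print(options.epochs)
```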