add protonet1
This commit is contained in: parent aaea093af9, commit c591801473

@ -0,0 +1,168 @@ README.md
# Contents

- [ProtoNet Description](#protonet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)

# [ProtoNet Description](#contents)

MindSpore implementation of the NeurIPS 2017 paper [Prototypical Networks for Few-shot Learning](https://arxiv.org/abs/1703.05175).

# [Model Architecture](#contents)

ProtoNet is a single convolutional encoder made of four identical blocks; each block stacks a 3x3 convolution, batch normalization, ReLU, and 2x2 max pooling. Class prototypes are computed as the mean of the embedded support samples, and query samples are classified by their Euclidean distance to each prototype.
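
For intuition, the classification rule can be sketched in a few lines of NumPy (an illustration of ours, not repo code; the MindSpore version lives in `src/PrototypicalLoss.py`):

```python
import numpy as np

def prototypical_predict(support_emb, support_labels, query_emb):
    """Nearest-prototype classification over encoder outputs."""
    classes = np.unique(support_labels)
    # Prototype of each class = mean of its support embeddings.
    prototypes = np.stack([support_emb[support_labels == c].mean(0) for c in classes])
    # Squared Euclidean distance from every query to every prototype.
    dists = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(-1)
    return classes[dists.argmin(1)]
```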

# [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in the original paper or a dataset widely used in the relevant domain/network architecture. The following sections describe how to run the scripts using the dataset below.

Dataset used: [Omniglot](https://github.com/brendenlake/omniglot)

- Dataset size: 4.02 MB, 32,462 28x28 images in 1,622 classes
    - Train: 1,200 classes
    - Test: 422 classes
- Data format: .png files
- Note: the data has been preprocessed into omniglot_resized
- The directory structure is as follows:

```text
└─Data
    ├─raw
    ├─splits
    │   └─vinyals
    │       test.txt
    │       train.txt
    │       val.txt
    │       trainval.txt
    └─data
        Alphabet_of_the_Magi
        Angelic
```

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```shell
# enter script dir, train ProtoNet standalone
sh run_standalone_train_ascend.sh dataset 1 20 20
# enter script dir, train ProtoNet distributed
sh run_distribution_ascend.sh rank_table dataset 20
# enter script dir, evaluate ProtoNet; ckpt_dir is the directory containing best_ck.ckpt
sh run_standalone_eval_ascend.sh dataset ckpt_dir 1 20
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
├── cv
    ├── ProtoNet
        ├── requirements.txt
        ├── README.md                          // descriptions about ProtoNet
        ├── scripts
        │   ├── run_standalone_train_ascend.sh // train on Ascend
        │   ├── run_standalone_eval_ascend.sh  // evaluate on Ascend
        │   ├── run_distribution_ascend.sh     // distributed training on Ascend
        ├── src
        │   ├── parser_util.py                 // parameter configuration
        │   ├── dataset.py                     // dataset creation
        │   ├── IterDatasetGenerator.py        // episode batch generation
        │   ├── protonet.py                    // ProtoNet architecture
        │   ├── PrototypicalLoss.py            // loss function
        │   ├── EvalCallBack.py                // evaluation callback
        ├── model_init.py                      // dataloader and lr scheduler initialization
        ├── export.py                          // export checkpoint into AIR/ONNX/MINDIR
        ├── train.py                           // training script
        ├── eval.py                            // evaluation script
```

## [Script Parameters](#contents)

```text
Major parameters in train.py and parser_util.py are as follows:

--dataset_root: path to the train and evaluation datasets.
--experiment_root: path where checkpoints and outputs are saved.
--classes_per_it_tr: number of random classes per training episode.
--num_support_tr: number of support samples per class for training.
--num_query_tr: number of query samples per class for training.
--classes_per_it_val: number of random classes per validation episode.
--num_support_val: number of support samples per class for validation.
--num_query_val: number of query samples per class for validation.
--epochs: total training epochs.
--iterations: number of episodes per epoch.
--learning_rate: learning rate.
--device_target: device where the code will run.
```
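
A quick way to sanity-check these flags is to drive the parser directly (a sketch; every flag below is defined in `src/parser_util.py`):

```python
from src.parser_util import get_parser

options = get_parser().parse_args([
    '--dataset_root', './dataset',
    '--classes_per_it_tr', '20',
    '--num_support_tr', '5',
    '--num_query_tr', '5',
    '--epochs', '20',
])
print(options.classes_per_it_tr, options.epochs)  # 20 20
```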

## [Training Process](#contents)

### Training

```bash
# enter script dir, train ProtoNet standalone
sh run_standalone_train_ascend.sh dataset 1 20 20
```

The model checkpoint will be saved in the current directory.

## [Evaluation Process](#contents)

### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

```bash
# enter script dir, evaluate ProtoNet; ckpt_dir is the directory containing best_ck.ckpt
sh run_standalone_eval_ascend.sh dataset ckpt_dir 1 20
```

```text
Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
```

# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | ProtoNet                                                      |
| -------------------------- | ------------------------------------------------------------- |
| Resource                   | CentOS 8.2; Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB |
| Uploaded date              | 03/26/2021 (month/day/year)                                   |
| MindSpore version          | 1.2.0                                                         |
| Dataset                    | Omniglot                                                      |
| Training parameters        | episode=500, class_num=5, lr=0.001, classes_per_it_tr=60, num_support_tr=5, num_query_tr=5, classes_per_it_val=20, num_support_val=5, num_query_val=15 |
| Optimizer                  | Adam                                                          |
| Loss function              | PrototypicalLoss                                              |
| Outputs                    | Accuracy                                                      |
| Loss                       | 0.002                                                         |
| Speed                      | 215 ms/step                                                   |
| Total time                 | 3 h 23 m (8p)                                                 |
| Checkpoint for fine-tuning | 440 KB (.ckpt file)                                           |
| Scripts                    | [ProtoNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/protonet) |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,71 @@ eval.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet evaluation script.
"""
import os
from mindspore import dataset as ds
from mindspore import load_checkpoint
import mindspore.context as context
from src.protonet import ProtoNet
from src.parser_util import get_parser
from src.PrototypicalLoss import PrototypicalLoss
import numpy as np
from model_init import init_dataloader
from train import WithLossCell


def test(test_dataloader, net):
    """
    Run 10 passes over the evaluation episodes and report mean accuracy and loss.
    """
    inp = ds.GeneratorDataset(test_dataloader, column_names=['data', 'label', 'classes'])
    avg_acc = list()
    avg_loss = list()
    for _ in range(10):
        i = 0
        for batch in inp.create_dict_iterator():
            i = i + 1
            print(i)
            x = batch['data']
            y = batch['label']
            classes = batch['classes']
            acc, loss = net(x, y, classes)
            avg_acc.append(acc.asnumpy())
            avg_loss.append(loss.asnumpy())
    print('eval end')
    avg_acc = np.mean(avg_acc)
    avg_loss = np.mean(avg_loss)
    print('Test Acc: {} Loss: {}'.format(avg_acc, avg_loss))


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE)
    options = get_parser().parse_args()
    if options.run_offline:
        datapath = options.dataset_root
        ckptpath = options.experiment_root
    else:
        import moxing as mox
        # Running on ModelArts: pull data and checkpoint to the local cache.
        mox.file.copy_parallel(src_url=options.data_url, dst_url='cache/data')
        mox.file.copy_parallel(src_url=options.ckpt_url, dst_url='cache/ckpt')
        datapath = 'cache/data'
        ckptpath = 'cache/ckpt'
    Net = ProtoNet()
    loss_fn = PrototypicalLoss(options.num_support_val, options.num_query_val,
                               options.classes_per_it_val, is_train=False)
    Net = WithLossCell(Net, loss_fn)
    val_dataloader = init_dataloader(options, 'val', datapath)
    load_checkpoint(os.path.join(ckptpath, 'best_ck.ckpt'), net=Net)
    test(val_dataloader, Net)
@ -0,0 +1,49 @@ export.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""

import argparse
import numpy as np

import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export

from src.protonet import ProtoNet as protonet

parser = argparse.ArgumentParser(description='ProtoNet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="protonet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == "__main__":

    # define network
    network = protonet()
    # load network checkpoint
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, param_dict)

    # export network
    inputs = Tensor(np.ones([args.batch_size, 1, 28, 28]), mindspore.float32)
    export(network, inputs, file_name=args.file_name, file_format=args.file_format)
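
# Example invocation (ours; every flag is defined by the parser above):
#   python export.py --ckpt_file ./output/best_ck.ckpt --file_format MINDIR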
@ -0,0 +1,63 @@ model_init.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet model init script.
"""
import itertools
import mindspore.nn as nn
import numpy as np
from src.dataset import OmniglotDataset
from src.IterDatasetGenerator import IterDatasetGenerator


def init_lr_scheduler(opt):
    '''
    Initialize the learning rate scheduler
    '''
    epochs = opt.epochs
    # Milestones at epochs 1, 1 + step, 1 + 2 * step, ... below `epochs`.
    milestone = list(itertools.takewhile(lambda n: n < epochs, itertools.count(1, opt.lr_scheduler_step)))

    lr0 = opt.learning_rate
    # Decay factors gamma ** 0, gamma ** 1, ..., one per milestone.
    bl = list(np.logspace(0, len(milestone)-1, len(milestone), base=opt.lr_scheduler_gamma))
    lr = [lr0*b for b in bl]
    lr_epoch = nn.piecewise_constant_lr(milestone, lr)
    return lr_epoch
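
# Worked example (ours, not repo code): with epochs=50, lr_scheduler_step=20,
# learning_rate=0.001 and lr_scheduler_gamma=0.5, the lists above come out as
#   milestone == [1, 21, 41]
#   lr        == [0.001, 0.0005, 0.00025]
# i.e. the rate is multiplied by gamma at each milestone.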


def init_dataset(opt, mode, path):
    '''
    Initialize the dataset
    '''
    dataset = OmniglotDataset(mode=mode, root=path)
    n_classes = len(np.unique(dataset.y))
    if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:
        raise ValueError('There are not enough classes in the dataset to '
                         'satisfy the chosen classes_per_it. Decrease the '
                         'classes_per_it_{tr/val} option and try again.')
    return dataset


def init_dataloader(opt, mode, path):
    '''
    Initialize the dataloader
    '''
    dataset = init_dataset(opt, mode, path)
    if 'train' in mode:
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    dataloader = IterDatasetGenerator(dataset, classes_per_it, num_samples, opt.iterations)
    return dataloader
@ -0,0 +1,3 @@ requirements.txt
numpy >= 1.17.0
tqdm >= 4.61.0
pillow >= 8.2.0
@ -0,0 +1,55 @@ scripts/run_distribution_ascend.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
if [ $# != 3 ]
then
    echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS]"
    exit 1
fi

if [ ! -f $1 ]
then
    echo "error: RANK_TABLE_FILE=$1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
export TRAIN_CLASS=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp -r ./src ./train_parallel$i
    cp ./train.py ./train_parallel$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd ./train_parallel$i || exit
    env > env.log
    python train.py --dataset_root=$DATA_PATH \
                    --device_id=$DEVICE_ID --device_target="Ascend" \
                    --classes_per_it_tr=$TRAIN_CLASS > log 2>&1 &
    cd ..
done
@ -0,0 +1,30 @@ scripts/run_standalone_eval_ascend.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
if [ $# != 4 ]
then
    echo "Usage: sh run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [EVAL_CLASS]"
    exit 1
fi

export DATA_PATH=$1
export CKPT_PATH=$2
export DEVICE_ID=$3
export EVAL_CLASS=$4

python ../eval.py --dataset_root=$DATA_PATH --experiment_root=$CKPT_PATH \
                  --device_id=$DEVICE_ID --device_target="Ascend" \
                  --classes_per_it_val=$EVAL_CLASS > eval_log 2>&1 &
@ -0,0 +1,31 @@ scripts/run_standalone_train_ascend.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
if [ $# != 4 ]
then
    echo "Usage: sh run_standalone_train_ascend.sh [DATA_PATH] [DEVICE_ID] [TRAIN_CLASS] [EPOCHS]"
    exit 1
fi

export DATA_PATH=$1
export DEVICE_ID=$2
export TRAIN_CLASS=$3
export EPOCHS=$4

python ../train.py --dataset_root=$DATA_PATH \
                   --device_id=$DEVICE_ID --device_target="Ascend" \
                   --classes_per_it_tr=$TRAIN_CLASS \
                   --epochs=$EPOCHS > log 2>&1 &
@ -0,0 +1,88 @@ src/EvalCallBack.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Callback for eval
"""

import os
from mindspore.train.callback import Callback
from mindspore import save_checkpoint
import numpy as np


class EvalCallBack(Callback):
    """
    Callback that evaluates the network after every epoch and keeps the
    best checkpoint.
    """
    def __init__(self, options, net, eval_dataset, path):
        super(EvalCallBack, self).__init__()
        self.net = net
        self.eval_dataset = eval_dataset
        self.path = path
        self.avgacc = 0
        self.avgloss = 0
        self.bestacc = 0
        self.options = options

    def epoch_begin(self, run_context):
        """
        Callback at epoch begin.
        """
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        print('=========EPOCH {} BEGIN========='.format(cur_epoch))

    def epoch_end(self, run_context):
        """
        Callback at epoch end: evaluate and save the best/last checkpoint.
        """
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        cur_net = cb_param.network
        evalnet = self.net
        self.avgacc, self.avgloss = self.eval(self.eval_dataset, evalnet)

        if self.avgacc > self.bestacc:
            self.bestacc = self.avgacc
            print('Epoch {}: Avg Accuracy: {}(best) Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
            best_path = os.path.join(self.path, 'best_ck.ckpt')
            save_checkpoint(cur_net, best_path)
        else:
            print('Epoch {}: Avg Accuracy: {} Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
            last_path = os.path.join(self.path, 'last_ck.ckpt')
            save_checkpoint(cur_net, last_path)
        print("Best Acc:", self.bestacc)
        print('=========EPOCH {} END========='.format(cur_epoch))

    def eval(self, inp, net):
        """
        Run 10 passes over the evaluation episodes and average the results.
        """
        avg_acc = list()
        avg_loss = list()
        for _ in range(10):
            for batch in inp.create_dict_iterator():
                x = batch['data']
                y = batch['label']
                classes = batch['classes']
                acc, loss = net(x, y, classes)
                avg_acc.append(acc.asnumpy())
                avg_loss.append(loss.asnumpy())
        avg_acc = np.mean(avg_acc)
        avg_loss = np.mean(avg_loss)

        return avg_acc, avg_loss
@ -0,0 +1,80 @@ src/IterDatasetGenerator.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset iter generator script.
"""
import numpy as np
from tqdm import tqdm


class IterDatasetGenerator:
    """
    Episode generator: each iteration yields `classes_per_it` random classes
    with `num_samples` samples per class.
    """
    def __init__(self, data, classes_per_it, num_samples, iterations):
        self.__iterations = iterations
        self.__data = data.x
        self.__labels = data.y
        self.__iter = 0
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.classes, self.counts = np.unique(self.__labels, return_counts=True)
        self.idxs = range(len(self.__labels))
        # indexes[c, j] holds the j-th sample index of class c; unfilled
        # slots stay NaN until assigned below.
        self.indexes = np.full((len(self.classes), max(self.counts)), np.nan)
        self.numel_per_class = np.zeros_like(self.classes)
        for idx, label in tqdm(enumerate(self.__labels)):
            label_idx = np.argwhere(self.classes == label).item()
            self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
            self.numel_per_class[label_idx] = int(self.numel_per_class[label_idx]) + 1

        print('init end')

    def __next__(self):
        spc = self.sample_per_class
        cpi = self.classes_per_it

        if self.__iter >= self.__iterations:
            raise StopIteration
        batch_size = spc * cpi
        # Buffer for the episode's sample indices; fully overwritten below.
        batch = np.empty(batch_size, dtype=np.int32)
        c_idxs = np.random.permutation(len(self.classes))[:cpi]
        for indx, c in enumerate(self.classes[c_idxs]):
            index = indx*spc
            label_idx = np.argwhere(self.classes == c).item()
            sample_idxs = np.random.permutation(int(self.numel_per_class[label_idx]))[:spc]
            ind = 0
            for sid in sample_idxs:
                batch[index+ind] = self.indexes[label_idx][sid]
                ind = ind + 1
        # Shuffle the samples within the episode.
        batch = batch[np.random.permutation(len(batch))]
        data_x = []
        data_y = []
        for b in batch:
            data_x.append(self.__data[b])
            data_y.append(self.__labels[b])
        self.__iter += 1
        data_y = np.asarray(data_y, np.int32)
        data_class = np.asarray(np.unique(data_y), np.int32)
        item = (data_x, data_y, data_class)
        return item

    def __iter__(self):
        self.__iter = 0
        return self

    def __len__(self):
        return self.__iterations
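

if __name__ == '__main__':
    # Smoke test with a stand-in dataset object (ours, not the real
    # OmniglotDataset): 4 classes x 30 samples, 2-way episodes of 10 samples.
    class _FakeData:
        x = [np.zeros((1, 28, 28), np.float32)] * 120
        y = [i // 30 for i in range(120)]

    gen = IterDatasetGenerator(_FakeData(), classes_per_it=2, num_samples=10, iterations=3)
    for ep_x, ep_y, ep_classes in gen:
        print(len(ep_x), ep_y.shape, ep_classes)  # 20 (20,) [two class ids]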
@ -0,0 +1,131 @@ src/PrototypicalLoss.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function script.
"""
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.loss.loss import _Loss
import mindspore as ms
import numpy as np


class PrototypicalLoss(_Loss):
    '''
    Loss class deriving from Module for the prototypical loss function defined below
    '''
    def __init__(self, n_support, n_query, n_class, is_train=True):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support
        self.n_query = n_query
        self.eq = ops.Equal()
        self.log_softmax = nn.LogSoftmax(1)
        self.gather = ops.GatherD()
        self.squeeze = ops.Squeeze()
        self.max = ops.Argmax(2)
        self.cast = ops.Cast()
        self.stack = ops.Stack()
        self.reshape = ops.Reshape()
        self.topk = ops.TopK(sorted=True)
        self.expand_dims = ops.ExpandDims()
        self.broadcast_to = ops.BroadcastTo((100, 20, 64))
        self.pow = ops.Pow()
        self.sum = ops.ReduceSum()
        self.zeros = Tensor(np.zeros(200), ms.float32)
        self.ones = Tensor(np.ones(200), ms.float32)
        self.print = ops.Print()
        self.unique = ops.Unique()
        self.samples_count = 10
        self.select = ops.Select()
        self.target_inds = Tensor(list(range(0, n_class)), ms.int32)
        self.is_train = is_train
        self.acc_val = 0

    def construct(self, inp, target, classes):
        """
        loss construct
        """
        n_classes = len(classes)
        n_query = self.n_query
        support_idxs = ()
        query_idxs = ()

        for ind, _ in enumerate(classes):
            class_c = classes[ind]
            # Indices of the n_support + n_query samples of this class.
            _, a = self.topk(self.cast(self.eq(target, class_c), ms.float32), self.n_support + self.n_query)
            support_idx = self.squeeze(a[:self.n_support])
            support_idxs += (support_idx,)
            query_idx = a[self.n_support:]
            query_idxs += (query_idx,)

        # Prototype of each class = mean embedding of its support samples.
        prototypes = ()
        for idx_list in support_idxs:
            prototypes += (inp[idx_list].mean(0),)
        prototypes = self.stack(prototypes)

        query_idxs = self.stack(query_idxs).view(-1)
        query_samples = inp[query_idxs]

        dists = euclidean_dist(query_samples, prototypes)

        log_p_y = self.log_softmax(-dists)

        log_p_y = self.reshape(log_p_y, (n_classes, n_query, -1))

        target_inds = self.target_inds.view(n_classes, 1, 1)
        target_inds = ops.BroadcastTo((n_classes, n_query, 1))(target_inds)  # to int64

        loss_val = -self.squeeze(self.gather(log_p_y, 2, target_inds)).view(-1).mean()

        y_hat = self.max(log_p_y)
        acc_val = self.cast(self.eq(y_hat, self.squeeze(target_inds)), ms.float32).mean()
        if self.is_train:
            return loss_val
        return acc_val, loss_val

    def supp_idxs(self, target, c):
        return self.squeeze(self.nonZero(self.eq(target, c))[:self.n_support])

    def nonZero(self, inpbool):
        """Return the indices of the True entries."""
        out = []
        for i, inp in enumerate(inpbool):
            if inp:
                out.append(i)
        return Tensor(out, ms.int32)

    def acc(self):
        return self.acc_val


def euclidean_dist(x, y):
    '''
    Compute euclidean distance between two tensors
    '''
    # x: N x D
    # y: M x D
    n = x.shape[0]
    m = y.shape[0]
    d = x.shape[1]

    expand_dims = ops.ExpandDims()
    broadcast_to = ops.BroadcastTo((n, m, d))
    pow_op = ops.Pow()
    reducesum = ops.ReduceSum()

    # (n, d) -> (n, m, d) and (m, d) -> (n, m, d), then sum squared diffs over d.
    x = broadcast_to(expand_dims(x, 1))
    y = broadcast_to(expand_dims(y, 0))
    return reducesum(pow_op(x-y, 2), 2)
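

if __name__ == '__main__':
    # Sanity check of euclidean_dist against NumPy broadcasting (a local
    # harness we added; not part of the training pipeline).
    _x = np.random.rand(4, 3).astype(np.float32)
    _y = np.random.rand(5, 3).astype(np.float32)
    _ref = ((_x[:, None, :] - _y[None, :, :]) ** 2).sum(-1)
    _out = euclidean_dist(Tensor(_x), Tensor(_y)).asnumpy()
    print(np.abs(_out - _ref).max())  # should be ~0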
@ -0,0 +1,129 @@ src/dataset.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset for ProtoNet
"""
import os
from PIL import Image
import numpy as np

IMG_CACHE = {}


class OmniglotDataset():
    """
    Omniglot dataset class
    """

    splits_folder = os.path.join('splits', 'vinyals')
    raw_folder = 'raw'
    processed_folder = 'data'

    def __init__(self, mode='train', root='.' + os.sep + 'dataset', transform=None, target_transform=None):
        self.root = root
        print(self.root)
        self.transform = transform
        self.target_transform = target_transform

        self.classes = get_current_classes(os.path.join(
            self.root, self.splits_folder, mode + '.txt'))
        self.all_items = find_items(os.path.join(
            self.root, self.processed_folder), self.classes)

        self.idx_classes = index_classes(self.all_items)
        paths, self.y = zip(*[self.get_path_label(pl)
                              for pl in range(len(self))])
        self.x = list(map(load_img, paths, range(len(paths))))

    def __getitem__(self, idx):
        x = self.x[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]

    def __len__(self):
        return len(self.all_items)

    def get_path_label(self, index):
        filename = self.all_items[index][0]
        rot = self.all_items[index][-1]
        img = str.join(os.sep, [self.all_items[index][2], filename]) + rot
        target = self.idx_classes[self.all_items[index][1] + self.all_items[index][-1]]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


def find_items(root_dir, classes):
    """
    Collect (file, label, root, rotation) tuples for every image whose
    class/rotation pair appears in `classes`.
    """
    retour = []
    rots = [os.sep + 'rot000', os.sep + 'rot090', os.sep + 'rot180', os.sep + 'rot270']
    for (root, _, files) in os.walk(root_dir):
        for f in files:
            r = root.split(os.sep)
            lr = len(r)
            label = r[lr - 2] + os.sep + r[lr - 1]
            for rot in rots:
                if label + rot in classes and (f.endswith("png")):
                    retour.extend([(f, label, root, rot)])
    print("== Dataset: Found %d items " % len(retour))
    return retour


def index_classes(items):
    """
    Map each (class, rotation) pair to an integer class index.
    """
    idx = {}
    for i in items:
        if not i[1] + i[-1] in idx:
            idx[i[1] + i[-1]] = len(idx)
    print("== Dataset: Found %d classes" % len(idx))
    return idx


def get_current_classes(fname):
    """
    Read the class list for the current split.
    """
    with open(fname) as f:
        classes = f.read().replace('/', os.sep).splitlines()
    return classes


def load_img(path, idx):
    """
    Load an image, apply its rotation, and return a 1x28x28 float array.
    """
    path, rot = path.split(os.sep + 'rot')
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    x = x.rotate(float(rot))
    x = x.resize((28, 28))

    shape = 1, x.size[0], x.size[1]
    x = np.array(x, np.float32, copy=False)
    # Invert so strokes are 1 and background is 0, then move to CHW layout.
    x = 1.0 - x
    x = x.T.reshape(shape)

    return x
@ -0,0 +1,118 @@ src/parser_util.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet parser_util script.
"""
import os
import argparse


def get_parser():
    """
    Build the command-line argument parser for ProtoNet.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--run_offline', default=True, help='whether to run offline (True) or on ModelArts (False)')

    parser.add_argument('--data_url', default=None, help='Location of data.')

    parser.add_argument('--train_url', default=None, help='Location of training outputs.')

    parser.add_argument('--ckpt_url', default=None, help='Location of checkpoints.')

    parser.add_argument('-root', '--dataset_root',
                        type=str,
                        help='path to dataset',
                        default='..' + os.sep + 'dataset')

    parser.add_argument('-target', '--device_target',
                        type=str,
                        help='device target, Ascend by default',
                        default='Ascend')
    parser.add_argument('-id', '--device_id',
                        type=int,
                        help='device id',
                        default=0)

    parser.add_argument('-exp', '--experiment_root',
                        type=str,
                        help='root where to store models, losses and accuracies',
                        default='..' + os.sep + 'output')

    parser.add_argument('-nep', '--epochs',
                        type=int,
                        help='number of epochs to train for',
                        default=2)

    parser.add_argument('-lr', '--learning_rate',
                        type=float,
                        help='learning rate for the model, default=0.001',
                        default=0.001)

    parser.add_argument('-lrS', '--lr_scheduler_step',
                        type=int,
                        help='StepLR learning rate scheduler step, default=20',
                        default=20)

    parser.add_argument('-lrG', '--lr_scheduler_gamma',
                        type=float,
                        help='StepLR learning rate scheduler gamma, default=0.5',
                        default=0.5)

    parser.add_argument('-its', '--iterations',
                        type=int,
                        help='number of episodes per epoch, default=100',
                        default=100)

    parser.add_argument('-cTr', '--classes_per_it_tr',
                        type=int,
                        help='number of random classes per episode for training, default=20',
                        default=20)

    parser.add_argument('-nsTr', '--num_support_tr',
                        type=int,
                        help='number of samples per class to use as support for training, default=5',
                        default=5)

    parser.add_argument('-nqTr', '--num_query_tr',
                        type=int,
                        help='number of samples per class to use as query for training, default=5',
                        default=5)

    parser.add_argument('-cVa', '--classes_per_it_val',
                        type=int,
                        help='number of random classes per episode for validation, default=20',
                        default=20)

    parser.add_argument('-nsVa', '--num_support_val',
                        type=int,
                        help='number of samples per class to use as support for validation, default=5',
                        default=5)

    parser.add_argument('-nqVa', '--num_query_val',
                        type=int,
                        help='number of samples per class to use as query for validation, default=15',
                        default=15)

    parser.add_argument('-seed', '--manual_seed',
                        type=int,
                        help='input for the manual seeds initializations',
                        default=7)

    parser.add_argument('--cuda',
                        action='store_true',
                        help='enables cuda')

    return parser
@ -0,0 +1,257 @@ src/protonet.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet.
"""
from functools import reduce
import math
import numpy as np
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops
from mindspore.common import initializer as init


def _calculate_gain(nonlinearity, param=None):
    r"""
    Return the recommended gain value for the given nonlinearity function.

    The values are as follows:
    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _assignment(arr, num):
    """Assign the value of `num` to `arr`."""
    if arr.shape == ():
        arr = arr.reshape((1))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_in_and_out(arr):
    """
    Calculate n_in and n_out.

    Args:
        arr (Array): Input array.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
    """
    dim = len(arr.shape)
    if dim < 2:
        raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.")

    n_in = arr.shape[1]
    n_out = arr.shape[0]

    if dim > 2:
        counter = reduce(lambda x, y: x * y, arr.shape[2:])
        n_in *= counter
        n_out *= counter
    return n_in, n_out


def _select_fan(array, mode):
    """
    Select fan_in or fan_out according to `mode`.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_in_and_out(array)
    return fan_in if mode == 'fan_in' else fan_out


class KaimingInit(init.Initializer):
    r"""
    Base class. Initialize the array with the He (Kaiming) algorithm.

    Args:
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function, recommended to use only with
            ``'relu'`` or ``'leaky_relu'`` (default).
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingInit, self).__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        # Subclasses implement the actual sampling.
        pass


class KaimingUniform(KaimingInit):
    r"""
    Initialize the array with the He (Kaiming) uniform algorithm. The resulting tensor
    will have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingUniform(mode='fan_in', nonlinearity='relu')._initialize(w)
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        data = np.random.uniform(-bound, bound, arr.shape)

        _assignment(arr, data)


class KaimingNormal(KaimingInit):
    r"""
    Initialize the array with the He (Kaiming) normal algorithm. The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingNormal(mode='fan_out', nonlinearity='relu')._initialize(w)
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        std = self.gain / math.sqrt(fan)
        data = np.random.normal(0, std, arr.shape)

        _assignment(arr, data)


def conv_block(in_channels, out_channels):
    '''
    returns a block conv-bn-relu-pool
    '''
    return nn.SequentialCell(
        nn.Conv2d(in_channels, out_channels, 3, pad_mode='pad', padding=1, has_bias=True),
        nn.BatchNorm2d(out_channels, momentum=0.1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2)
    )


class ProtoNet(nn.Cell):
    '''
    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.SequentialCell(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )
        self._initialize_weights()
        self.print = ops.Print()

    def construct(self, x):
        x = self.encoder(x)
        reshape = ops.Reshape()
        # Flatten the (N, C, H, W) feature map into (N, C * H * W) embeddings.
        x = reshape(x, (x.shape[0], -1))
        return x

    def _initialize_weights(self):
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                      cell.weight.shape,
                                                      ms.float32))
                if cell.bias is not None:
                    fan_in, _ = _calculate_in_and_out(cell.weight)
                    bound = 1 / math.sqrt(fan_in)
                    cell.bias.set_data(init.initializer(init.Uniform(bound),
                                                        cell.bias.shape,
                                                        ms.float32))


class WithLossCell(nn.Cell):
    """
    Wrap the backbone and the loss function into a single cell.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label, classes):
        out = self._backbone(data)
        return self._loss_fn(out, label, classes)
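

if __name__ == '__main__':
    # Shape sanity check (ours): a 1x28x28 Omniglot image embeds to 64 values
    # (28 -> 14 -> 7 -> 3 -> 1 after the four pooling blocks, with z_dim=64).
    _net = ProtoNet()
    _out = _net(ms.Tensor(np.ones((100, 1, 28, 28)), ms.float32))
    print(_out.shape)  # (100, 64)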
@ -0,0 +1,124 @@ train.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet train script.
"""
import os
import datetime
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore import dataset as ds
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from src.EvalCallBack import EvalCallBack
from src.protonet import ProtoNet, WithLossCell
from src.PrototypicalLoss import PrototypicalLoss
from src.parser_util import get_parser
from model_init import init_dataloader

local_data_url = './cache/data'
local_train_url = './cache/out'


def train(opt, tr_dataloader, net, loss_fn, eval_loss_fn, optim, path, val_dataloader=None):
    '''
    train function
    '''
    inp = ds.GeneratorDataset(tr_dataloader, column_names=['data', 'label', 'classes'])
    my_loss_cell = WithLossCell(net, loss_fn)
    my_acc_cell = WithLossCell(net, eval_loss_fn)
    model = Model(my_loss_cell, optimizer=optim)

    eval_data = ds.GeneratorDataset(val_dataloader, column_names=['data', 'label', 'classes'])

    eval_cb = EvalCallBack(opt, my_acc_cell, eval_data, path)
    config = CheckpointConfig(save_checkpoint_steps=10,
                              keep_checkpoint_max=5,
                              saved_network=net)
    ckpoint_cb = ModelCheckpoint(prefix='protonet', directory=path, config=config)

    print('==========training test==========')
    starttime = datetime.datetime.now()
    model.train(opt.epochs, inp, callbacks=[ckpoint_cb, eval_cb, TimeMonitor()])
    endtime = datetime.datetime.now()
    # Rough averages; the divisors assume 10 epochs of 100 steps each.
    print('epoch time: ', (endtime - starttime).seconds / 10, 'per step time:', (endtime - starttime).seconds / 1000)


def main():
    '''
    main function
    '''
    global local_data_url
    global local_train_url

    options = get_parser().parse_args()

    if options.run_offline:

        device_num = int(os.environ.get("DEVICE_NUM", 1))

        if device_num > 1:
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
        context.set_context(device_id=options.device_id)
        local_data_url = options.dataset_root
        local_train_url = options.experiment_root
        if not os.path.exists(options.experiment_root):
            os.makedirs(options.experiment_root)
    else:
        device_num = int(os.environ.get("DEVICE_NUM", 1))
        device_id = int(os.getenv("DEVICE_ID"))

        import moxing as mox
        if not os.path.exists(local_train_url):
            os.makedirs(local_train_url)

        context.set_context(device_id=device_id)

        if device_num > 1:
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            local_data_url = os.path.join(local_data_url, str(device_id))
            local_train_url = os.path.join(local_train_url, str(device_id))

        mox.file.copy_parallel(src_url=options.data_url, dst_url=local_data_url)

    tr_dataloader = init_dataloader(options, 'train', local_data_url)
    val_dataloader = init_dataloader(options, 'val', local_data_url)

    loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_tr)
    # Validation episodes use the *_val sizes, matching eval.py.
    eval_loss_fn = PrototypicalLoss(options.num_support_val, options.num_query_val, options.classes_per_it_val,
                                    is_train=False)

    Net = ProtoNet()
    optim = nn.Adam(params=Net.trainable_params(), learning_rate=0.001)
    train(options, tr_dataloader, Net, loss_fn, eval_loss_fn, optim, local_train_url, val_dataloader)
    if not options.run_offline:
        mox.file.copy_parallel(src_url='./cache/out', dst_url=options.train_url)


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE)
    main()