commit f939057247

@@ -0,0 +1,168 @@

# Contents

- [ProtoNet Description](#protonet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)

# [ProtoNet Description](#contents)

A MindSpore implementation of the NeurIPS 2017 paper [Prototypical Networks for Few-shot Learning](https://arxiv.org/abs/1703.05175).

# [Model Architecture](#contents)

ProtoNet is an embedding encoder built from four identical convolution blocks (conv-BN-ReLU-max-pool). There are no extra classification layers: each class prototype is the mean of its support-set embeddings, and query samples are classified by Euclidean distance to the prototypes.
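A minimal sketch of one encoder block, mirroring `conv_block` in `src/protonet.py`:

```python
import mindspore.nn as nn

def conv_block(in_channels, out_channels):
    """One ProtoNet encoder block: conv -> batch norm -> ReLU -> 2x2 max-pool."""
    return nn.SequentialCell(
        nn.Conv2d(in_channels, out_channels, 3, pad_mode='pad', padding=1, has_bias=True),
        nn.BatchNorm2d(out_channels, momentum=0.1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2)
    )
```

The full encoder stacks four such blocks (x_dim -> hid_dim -> hid_dim -> hid_dim -> z_dim); each block halves the spatial resolution.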

# [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or one widely used in this domain/network architecture. The following sections describe how to run the scripts with the dataset below.

Dataset used: [Omniglot](https://github.com/brendenlake/omniglot)

- Dataset size: 4.02 MB, 32,462 28×28 images in 1,622 classes
    - Train: 1,200 classes
    - Test: 422 classes
- Data format: .png files
- Note: the data has already been processed as in omniglot_resized

- The directory structure is as follows:

```text
└─Data
    ├─raw
    ├─splits
    │   vinyals
    │       test.txt
    │       train.txt
    │       val.txt
    │       trainval.txt
    └─data
        Alphabet_of_the_Magi
        Angelic
```

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with Ascend.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```shell
# enter the script dir, train ProtoNet standalone
sh run_standalone_train_ascend.sh dataset 1 20 20
# enter the script dir, train ProtoNet distributed
sh run_distribution_ascend.sh rank_table dataset 20
# enter the script dir, evaluate ProtoNet
# (the second argument is the directory containing best_ck.ckpt)
sh run_standalone_eval_ascend.sh dataset ckpt_dir 1 20
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```shell
├── cv
    ├── ProtoNet
        ├── requirements.txt
        ├── README.md                          // description of ProtoNet
        ├── scripts
        │   ├──run_standalone_train_ascend.sh  // train on Ascend
        │   ├──run_standalone_eval_ascend.sh   // evaluate on Ascend
        │   ├──run_distribution_ascend.sh      // distributed training on Ascend
        ├── src
        │   ├──parser_util.py                  // parameter configuration
        │   ├──dataset.py                      // creating dataset
        │   ├──IterDatasetGenerator.py         // episodic batch generation
        │   ├──protonet.py                     // ProtoNet architecture
        │   ├──PrototypicalLoss.py             // loss function
        ├── train.py                           // training script
        ├── eval.py                            // evaluation script
```

## [Script Parameters](#contents)

```text
Major parameters in train.py and src/parser_util.py are as follows:

--classes_per_it_tr: number of random classes per training episode.
--num_support_tr: number of support samples per class for training.
--num_query_tr: number of query samples per class for training.
--classes_per_it_val: number of random classes per validation episode.
--num_support_val: number of support samples per class for validation.
--num_query_val: number of query samples per class for validation.
--epochs: total training epochs.
--iterations: number of episodes per epoch.
--learning_rate: learning rate, default 0.001.
--device_target: device where the code will run.
--dataset_root: path to the train and evaluation datasets.
--experiment_root: path where checkpoints are saved after training.
```
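
For reference, `sh run_standalone_train_ascend.sh dataset 1 20 20` expands to roughly the following call (the dataset path here is a placeholder):

```shell
python train.py --dataset_root=dataset --device_id=1 --device_target="Ascend" \
                --classes_per_it_tr=20 --epochs=20
```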

## [Training Process](#contents)

### Training

```bash
# enter the script dir, train ProtoNet standalone
sh run_standalone_train_ascend.sh dataset 1 20 20
```

The model checkpoints will be saved in the current directory.

## [Evaluation Process](#contents)

### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

```bash
# enter the script dir, evaluate ProtoNet
sh run_standalone_eval_ascend.sh dataset ckpt_dir 1 20
```

The evaluation output looks like:

```text
Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
```

# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | ProtoNet                                                    |
| -------------------------- | ----------------------------------------------------------- |
| Resource                   | CentOS 8.2; Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Uploaded date              | 03/26/2021 (month/day/year)                                 |
| MindSpore version          | 1.2.0                                                       |
| Dataset                    | Omniglot                                                    |
| Training parameters        | episode=500, class_num=5, lr=0.001, classes_per_it_tr=60, num_support_tr=5, num_query_tr=5, classes_per_it_val=20, num_support_val=5, num_query_val=15 |
| Optimizer                  | Adam                                                        |
| Loss function              | Prototypical loss                                           |
| Outputs                    | accuracy                                                    |
| Loss                       | 0.002                                                       |
| Speed                      | 215 ms/step                                                 |
| Total time                 | 3h 23m (8 pcs)                                              |
| Checkpoint for fine-tuning | 440 KB (.ckpt file)                                         |
| Scripts                    | [ProtoNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/protonet) |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,71 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet evaluation script.
"""
import os
import numpy as np
from mindspore import dataset as ds
from mindspore import load_checkpoint
import mindspore.context as context
from src.protonet import ProtoNet
from src.parser_util import get_parser
from src.PrototypicalLoss import PrototypicalLoss
from model_init import init_dataloader
from train import WithLossCell

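# Evaluation wraps the episodic generator in a GeneratorDataset and averages
# accuracy/loss over 10 passes through the validation episodes.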
def test(test_dataloader, net):
    """
    test function
    """
    inp = ds.GeneratorDataset(test_dataloader, column_names=['data', 'label', 'classes'])
    avg_acc = list()
    avg_loss = list()
    for _ in range(10):
        i = 0
        for batch in inp.create_dict_iterator():
            i = i + 1
            print(i)
            x = batch['data']
            y = batch['label']
            classes = batch['classes']
            acc, loss = net(x, y, classes)
            avg_acc.append(acc.asnumpy())
            avg_loss.append(loss.asnumpy())
    print('eval end')
    avg_acc = np.mean(avg_acc)
    avg_loss = np.mean(avg_loss)
    print('Test Acc: {} Loss: {}'.format(avg_acc, avg_loss))


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE)
    options = get_parser().parse_args()
    if options.run_offline:
        datapath = options.dataset_root
        ckptpath = options.experiment_root
    else:
        import moxing as mox
        mox.file.copy_parallel(src_url=options.data_url, dst_url='cache/data')
        mox.file.copy_parallel(src_url=options.ckpt_url, dst_url='cache/ckpt')
        datapath = 'cache/data'
        ckptpath = 'cache/ckpt'
    Net = ProtoNet()
    loss_fn = PrototypicalLoss(options.num_support_val, options.num_query_val,
                               options.classes_per_it_val, is_train=False)
    Net = WithLossCell(Net, loss_fn)
    val_dataloader = init_dataloader(options, 'val', datapath)
    load_checkpoint(os.path.join(ckptpath, 'best_ck.ckpt'), net=Net)
    test(val_dataloader, Net)

@@ -0,0 +1,49 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""

import argparse
import numpy as np

import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export

from src.protonet import ProtoNet as protonet

parser = argparse.ArgumentParser(description='MindSpore ProtoNet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="protonet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == "__main__":
    # define network
    network = protonet()
    # load network checkpoint
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, param_dict)

    # export network
    inputs = Tensor(np.ones([args.batch_size, 1, 28, 28]), mindspore.float32)
    export(network, inputs, file_name=args.file_name, file_format=args.file_format)

@@ -0,0 +1,63 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet model init script.
"""
import itertools
import numpy as np
import mindspore.nn as nn
from src.dataset import OmniglotDataset
from src.IterDatasetGenerator import IterDatasetGenerator


def init_lr_scheduler(opt):
    '''
    Initialize the learning rate scheduler
    '''
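    # Example: epochs=50, lr_scheduler_step=20 -> milestones [1, 21, 41];
    # with learning_rate=0.001 and lr_scheduler_gamma=0.5 the piecewise
    # rates are [0.001, 0.0005, 0.00025].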
    epochs = opt.epochs
    milestone = list(itertools.takewhile(lambda n: n < epochs, itertools.count(1, opt.lr_scheduler_step)))

    lr0 = opt.learning_rate
    bl = list(np.logspace(0, len(milestone) - 1, len(milestone), base=opt.lr_scheduler_gamma))
    lr = [lr0 * b for b in bl]
    lr_epoch = nn.piecewise_constant_lr(milestone, lr)
    return lr_epoch


def init_dataset(opt, mode, path):
    '''
    Initialize the dataset
    '''
    dataset = OmniglotDataset(mode=mode, root=path)
    n_classes = len(np.unique(dataset.y))
    if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:
        raise ValueError('There are not enough classes in the dataset in order '
                         'to satisfy the chosen classes_per_it. Decrease the '
                         'classes_per_it_{tr/val} option and try again.')
    return dataset


def init_dataloader(opt, mode, path):
    '''
    Initialize the dataloader
    '''
    dataset = init_dataset(opt, mode, path)
    if 'train' in mode:
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    dataloader = IterDatasetGenerator(dataset, classes_per_it, num_samples, opt.iterations)
    return dataloader

@@ -0,0 +1,3 @@

numpy >= 1.17.0
tqdm >= 4.61.0
pillow >= 8.2.0

@@ -0,0 +1,55 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple usage example follows; more parameters can be set.
if [ $# != 3 ]
then
    echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS]"
    exit 1
fi

if [ ! -f $1 ]
then
    echo "error: RANK_TABLE_FILE=$1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
export TRAIN_CLASS=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

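# Launch one training process per device; each process runs in its own
# working directory train_parallel$i with a copy of src/ and train.py.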
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp -r ./src ./train_parallel$i
    cp ./train.py ./train_parallel$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd ./train_parallel$i || exit
    env > env.log
    python train.py --dataset_root=$DATA_PATH \
                    --device_id=$DEVICE_ID --device_target="Ascend" \
                    --classes_per_it_tr=$TRAIN_CLASS > log 2>&1 &
    cd ..
done

@@ -0,0 +1,30 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple usage example follows; more parameters can be set.
if [ $# != 4 ]
then
    echo "Usage: sh run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [EVAL_CLASS]"
    exit 1
fi

export DATA_PATH=$1
export CKPT_PATH=$2
export DEVICE_ID=$3
export EVAL_CLASS=$4

python ../eval.py --dataset_root=$DATA_PATH --experiment_root=$CKPT_PATH \
                  --device_id=$DEVICE_ID --device_target="Ascend" \
                  --classes_per_it_val=$EVAL_CLASS > eval_log 2>&1 &

@@ -0,0 +1,31 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple usage example follows; more parameters can be set.
if [ $# != 4 ]
then
    echo "Usage: sh run_standalone_train_ascend.sh [DATA_PATH] [DEVICE_ID] [TRAIN_CLASS] [EPOCHS]"
    exit 1
fi

export DATA_PATH=$1
export DEVICE_ID=$2
export TRAIN_CLASS=$3
export EPOCHS=$4

python ../train.py --dataset_root=$DATA_PATH \
                   --device_id=$DEVICE_ID --device_target="Ascend" \
                   --classes_per_it_tr=$TRAIN_CLASS \
                   --epochs=$EPOCHS > log 2>&1 &

@@ -0,0 +1,88 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Callback for eval
"""

import os
import numpy as np
from mindspore.train.callback import Callback
from mindspore import save_checkpoint


class EvalCallBack(Callback):
    """
    CallBack class
    """
    def __init__(self, options, net, eval_dataset, path):
        super(EvalCallBack, self).__init__()
        self.net = net
        self.eval_dataset = eval_dataset
        self.path = path
        self.avgacc = 0
        self.avgloss = 0
        self.bestacc = 0
        self.options = options

    def epoch_begin(self, run_context):
        """
        CallBack epoch begin
        """
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        print('=========EPOCH {} BEGIN========='.format(cur_epoch))
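    # Evaluate after every epoch; keep best_ck.ckpt for the highest validation
    # accuracy seen so far and last_ck.ckpt for the most recent epoch.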
    def epoch_end(self, run_context):
        """
        CallBack epoch end
        """
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        cur_net = cb_param.network
        evalnet = self.net
        self.avgacc, self.avgloss = self.eval(self.eval_dataset, evalnet)

        if self.avgacc > self.bestacc:
            self.bestacc = self.avgacc
            print('Epoch {}: Avg Accuracy: {}(best) Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
            best_path = os.path.join(self.path, 'best_ck.ckpt')
            save_checkpoint(cur_net, best_path)
        else:
            print('Epoch {}: Avg Accuracy: {} Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
            last_path = os.path.join(self.path, 'last_ck.ckpt')
            save_checkpoint(cur_net, last_path)
        print("Best Acc:", self.bestacc)
        print('=========EPOCH {} END========='.format(cur_epoch))

    def eval(self, inp, net):
        """
        CallBack eval
        """
        avg_acc = list()
        avg_loss = list()
        for _ in range(10):
            for batch in inp.create_dict_iterator():
                x = batch['data']
                y = batch['label']
                classes = batch['classes']
                acc, loss = net(x, y, classes)
                avg_acc.append(acc.asnumpy())
                avg_loss.append(loss.asnumpy())
        avg_acc = np.mean(avg_acc)
        avg_loss = np.mean(avg_loss)

        return avg_acc, avg_loss

@@ -0,0 +1,80 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset iter generator script.
"""
import numpy as np
from tqdm import tqdm


class IterDatasetGenerator:
    """
    dataloader class
    """
    def __init__(self, data, classes_per_it, num_samples, iterations):
        self.__iterations = iterations
        self.__data = data.x
        self.__labels = data.y
        self.__iter = 0
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.classes, self.counts = np.unique(self.__labels, return_counts=True)
        self.idxs = range(len(self.__labels))
        self.indexes = np.empty((len(self.classes), max(self.counts)), dtype=int) * np.nan
        self.numel_per_class = np.zeros_like(self.classes)
        for idx, label in tqdm(enumerate(self.__labels)):
            label_idx = np.argwhere(self.classes == label).item()
            self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
            self.numel_per_class[label_idx] = int(self.numel_per_class[label_idx]) + 1

        print('init end')
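    # Each episode draws classes_per_it random classes and sample_per_class
    # dataset indices per class; the flattened index batch is shuffled before
    # the (data, label, classes) tuple is materialized.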
    def __next__(self):
        spc = self.sample_per_class
        cpi = self.classes_per_it

        if self.__iter >= self.__iterations:
            raise StopIteration
        batch_size = spc * cpi
        batch = np.random.randint(low=batch_size, high=10 * batch_size, size=(batch_size), dtype=np.int32)
        c_idxs = np.random.permutation(len(self.classes))[:cpi]
        for indx, c in enumerate(self.classes[c_idxs]):
            index = indx * spc
            ci = [c_i for c_i in range(len(self.classes)) if self.classes[c_i] == c][0]
            label_idx = list(range(len(self.classes)))[ci]
            sample_idxs = np.random.permutation(int(self.numel_per_class[label_idx]))[:spc]
            ind = 0
            for sid in sample_idxs:
                batch[index + ind] = self.indexes[label_idx][sid]
                ind = ind + 1
        batch = batch[np.random.permutation(len(batch))]
        data_x = []
        data_y = []
        for b in batch:
            data_x.append(self.__data[b])
            data_y.append(self.__labels[b])
        self.__iter += 1
        data_y = np.asarray(data_y, np.int32)
        data_class = np.asarray(np.unique(data_y), np.int32)
        item = (data_x, data_y, data_class)
        return item

    def __iter__(self):
        self.__iter = 0
        return self

    def __len__(self):
        return self.__iterations

@@ -0,0 +1,131 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function script.
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.nn.loss.loss import _Loss


class PrototypicalLoss(_Loss):
    '''
    Loss class deriving from Module for the prototypical loss function defined below
    '''
    def __init__(self, n_support, n_query, n_class, is_train=True):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support
        self.n_query = n_query
        self.eq = ops.Equal()
        self.log_softmax = nn.LogSoftmax(1)
        self.gather = ops.GatherD()
        self.squeeze = ops.Squeeze()
        self.max = ops.Argmax(2)
        self.cast = ops.Cast()
        self.stack = ops.Stack()
        self.reshape = ops.Reshape()
        self.topk = ops.TopK(sorted=True)
        self.expendDims = ops.ExpandDims()
        self.broadcastTo = ops.BroadcastTo((100, 20, 64))
        self.pow = ops.Pow()
        self.sum = ops.ReduceSum()
        self.zeros = Tensor(np.zeros(200), ms.float32)
        self.ones = Tensor(np.ones(200), ms.float32)
        self.print = ops.Print()
        self.unique = ops.Unique()
        self.samples_count = 10
        self.select = ops.Select()
        self.target_inds = Tensor(list(range(0, n_class)), ms.int32)
        self.is_train = is_train
        self.acc_val = 0
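    # construct splits each class's samples into support/query sets:
    # eq(target, c) gives a 0/1 mask, and TopK over that mask returns the
    # indices of the n_support + n_query samples of class c (the first
    # n_support become support, the rest query). Prototypes are the per-class
    # means of the support embeddings.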
    def construct(self, inp, target, classes):
        """
        loss construct
        """
        n_classes = len(classes)
        n_query = self.n_query
        support_idxs = ()
        query_idxs = ()

        for ind, _ in enumerate(classes):
            class_c = classes[ind]
            _, a = self.topk(self.cast(self.eq(target, class_c), ms.float32), self.n_support + self.n_query)
            support_idx = self.squeeze(a[:self.n_support])
            support_idxs += (support_idx,)
            query_idx = a[self.n_support:]
            query_idxs += (query_idx,)

        prototypes = ()
        for idx_list in support_idxs:
            prototypes += (inp[idx_list].mean(0),)
        prototypes = self.stack(prototypes)

        query_idxs = self.stack(query_idxs).view(-1)
        query_samples = inp[query_idxs]

        dists = euclidean_dist(query_samples, prototypes)

        log_p_y = self.log_softmax(-dists)
        log_p_y = self.reshape(log_p_y, (n_classes, n_query, -1))

        target_inds = self.target_inds.view(n_classes, 1, 1)
        target_inds = ops.BroadcastTo((n_classes, n_query, 1))(target_inds)  # to int64

        loss_val = -self.squeeze(self.gather(log_p_y, 2, target_inds)).view(-1).mean()

        y_hat = self.max(log_p_y)
        acc_val = self.cast(self.eq(y_hat, self.squeeze(target_inds)), ms.float32).mean()
        if self.is_train:
            return loss_val
        return acc_val, loss_val

    def supp_idxs(self, target, c):
        return self.squeeze(self.nonZero(self.eq(target, c))[:self.n_support])

    def nonZero(self, inpbool):
        out = []
        for _, inp in enumerate(inpbool):
            if inp:
                out.append(inp)
        return Tensor(out, ms.int32)

    def acc(self):
        return self.acc_val


def euclidean_dist(x, y):
    '''
    Compute euclidean distance between two tensors
    '''
    # x: N x D
    # y: M x D
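    # returns: N x M matrix of squared Euclidean distances (the square root
    # is omitted; it is monotone and does not change the softmax ranking).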
    n = x.shape[0]
    m = y.shape[0]
    d = x.shape[1]

    expendDims = ops.ExpandDims()
    broadcastTo = ops.BroadcastTo((n, m, d))
    pow_op = ops.Pow()
    reducesum = ops.ReduceSum()

    x = broadcastTo(expendDims(x, 1))
    y = broadcastTo(expendDims(y, 0))
    return reducesum(pow_op(x - y, 2), 2)

@@ -0,0 +1,129 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset for ProtoNet
"""
import os
from PIL import Image
import numpy as np

IMG_CACHE = {}


class OmniglotDataset():
    """
    Omniglot dataset class
    """

    splits_folder = os.path.join('splits', 'vinyals')
    raw_folder = 'raw'
    processed_folder = 'data'

    def __init__(self, mode='train', root='.' + os.sep + 'dataset', transform=None, target_transform=None):
        self.root = root
        print(self.root)
        self.transform = transform
        self.target_transform = target_transform

        self.classes = get_current_classes(os.path.join(
            self.root, self.splits_folder, mode + '.txt'))
        self.all_items = find_items(os.path.join(
            self.root, self.processed_folder), self.classes)

        self.idx_classes = index_classes(self.all_items)
        paths, self.y = zip(*[self.get_path_label(pl)
                              for pl in range(len(self))])
        self.x = map(load_img, paths, range(len(paths)))
        self.x = list(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]

    def __len__(self):
        return len(self.all_items)

    def get_path_label(self, index):
        filename = self.all_items[index][0]
        rot = self.all_items[index][-1]
        img = str.join(os.sep, [self.all_items[index][2], filename]) + rot
        target = self.idx_classes[self.all_items[index][1] + self.all_items[index][-1]]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


def find_items(root_dir, classes):
    """
    function to find items
    """
    retour = []
    rots = [os.sep + 'rot000', os.sep + 'rot090', os.sep + 'rot180', os.sep + 'rot270']
    for (root, _, files) in os.walk(root_dir):
        for f in files:
            r = root.split(os.sep)
            lr = len(r)
            label = r[lr - 2] + os.sep + r[lr - 1]
            for rot in rots:
                if label + rot in classes and (f.endswith("png")):
                    retour.extend([(f, label, root, rot)])
    print("== Dataset: Found %d items " % len(retour))
    return retour


def index_classes(items):
    """
    Map each (class, rotation) pair to an integer label.
    """
    idx = {}
    for i in items:
        if not i[1] + i[-1] in idx:
            idx[i[1] + i[-1]] = len(idx)
    print("== Dataset: Found %d classes" % len(idx))
    return idx


def get_current_classes(fname):
    """
    get current classes
    """
    with open(fname) as f:
        classes = f.read().replace('/', os.sep).splitlines()
    return classes
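# The 'rotNNN' suffix appended by find_items encodes a rotation augmentation;
# load_img splits it off the path and rotates the image by that many degrees.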
def load_img(path, idx):
    """
    function to load images
    """
    path, rot = path.split(os.sep + 'rot')
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    x = x.rotate(float(rot))
    x = x.resize((28, 28))

    shape = 1, x.size[0], x.size[1]
    x = np.array(x, np.float32, copy=False)
    x = 1.0 - x
    x = x.T.reshape(shape)

    return x

@@ -0,0 +1,118 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet parser_util script.
"""
import os
import argparse


def get_parser():
    """
    Build the command-line argument parser for ProtoNet.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--run_offline', default=True, help='run offline (True) or on the cloud (False)')

    parser.add_argument('--data_url', default=None, help='Location of data.')

    parser.add_argument('--train_url', default=None, help='Location of training outputs.')

    parser.add_argument('--ckpt_url', default=None, help='Location of checkpoints.')

    parser.add_argument('-root', '--dataset_root',
                        type=str,
                        help='path to dataset',
                        default='..' + os.sep + 'dataset')

    parser.add_argument('-target', '--device_target',
                        type=str,
                        help='device target',
                        default='Ascend')
    parser.add_argument('-id', '--device_id',
                        type=int,
                        help='device id',
                        default=0)

    parser.add_argument('-exp', '--experiment_root',
                        type=str,
                        help='root where to store models, losses and accuracies',
                        default='..' + os.sep + 'output')

    parser.add_argument('-nep', '--epochs',
                        type=int,
                        help='number of epochs to train for',
                        default=2)

    parser.add_argument('-lr', '--learning_rate',
                        type=float,
                        help='learning rate for the model, default=0.001',
                        default=0.001)

    parser.add_argument('-lrS', '--lr_scheduler_step',
                        type=int,
                        help='StepLR learning rate scheduler step, default=20',
                        default=20)

    parser.add_argument('-lrG', '--lr_scheduler_gamma',
                        type=float,
                        help='StepLR learning rate scheduler gamma, default=0.5',
                        default=0.5)

    parser.add_argument('-its', '--iterations',
                        type=int,
                        help='number of episodes per epoch, default=100',
                        default=100)

    parser.add_argument('-cTr', '--classes_per_it_tr',
                        type=int,
                        help='number of random classes per episode for training, default=20',
                        default=20)

    parser.add_argument('-nsTr', '--num_support_tr',
                        type=int,
                        help='number of samples per class to use as support for training, default=5',
                        default=5)

    parser.add_argument('-nqTr', '--num_query_tr',
                        type=int,
                        help='number of samples per class to use as query for training, default=5',
                        default=5)

    parser.add_argument('-cVa', '--classes_per_it_val',
                        type=int,
                        help='number of random classes per episode for validation, default=20',
                        default=20)

    parser.add_argument('-nsVa', '--num_support_val',
                        type=int,
                        help='number of samples per class to use as support for validation, default=5',
                        default=5)

    parser.add_argument('-nqVa', '--num_query_val',
                        type=int,
                        help='number of samples per class to use as query for validation, default=15',
                        default=15)

    parser.add_argument('-seed', '--manual_seed',
                        type=int,
                        help='input for the manual seeds initializations',
                        default=7)

    parser.add_argument('--cuda',
                        action='store_true',
                        help='enables cuda')

    return parser

@@ -0,0 +1,257 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet.
"""
from functools import reduce
import math
import numpy as np
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops
from mindspore.common import initializer as init


def _calculate_gain(nonlinearity, param=None):
    r"""
    Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and (isinstance(param, int) or isinstance(param, float)):
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _assignment(arr, num):
    """Assign the value of `num` to `arr`."""
    if arr.shape == ():
        arr = arr.reshape((1))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_in_and_out(arr):
    """
    Calculate n_in and n_out.

    Args:
        arr (Array): Input array.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
    """
    dim = len(arr.shape)
    if dim < 2:
        raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.")

    n_in = arr.shape[1]
    n_out = arr.shape[0]

    if dim > 2:
        counter = reduce(lambda x, y: x * y, arr.shape[2:])
        n_in *= counter
        n_out *= counter
    return n_in, n_out


def _select_fan(array, mode):
    """
    select fan
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_in_and_out(array)
    return fan_in if mode == 'fan_in' else fan_out


class KaimingInit(init.Initializer):
    r"""
    Base Class. Initialize the array with He kaiming algorithm.

    Args:
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function, recommended to use only with
            ``'relu'`` or ``'leaky_relu'`` (default).
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingInit, self).__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        pass


class KaimingUniform(KaimingInit):
    r"""
    Initialize the array with He kaiming uniform algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        data = np.random.uniform(-bound, bound, arr.shape)

        _assignment(arr, data)


class KaimingNormal(KaimingInit):
    r"""
    Initialize the array with He kaiming normal algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        std = self.gain / math.sqrt(fan)
        data = np.random.normal(0, std, arr.shape)

        _assignment(arr, data)


def conv_block(in_channels, out_channels):
    '''
    returns a block conv-bn-relu-pool
    '''
    return nn.SequentialCell(
        nn.Conv2d(in_channels, out_channels, 3, pad_mode='pad', padding=1, has_bias=True),
        nn.BatchNorm2d(out_channels, momentum=0.1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2)
    )


class ProtoNet(nn.Cell):
    '''
    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.SequentialCell(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )
        self._initialize_weights()
        self.print = ops.Print()
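    # Four conv blocks each halve the spatial size (28 -> 14 -> 7 -> 3 -> 1),
    # so flattening the encoder output yields a z_dim (64) feature per image.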
    def construct(self, x):
        x = self.encoder(x)
        reshape = ops.Reshape()
        x = reshape(x, (x.shape[0], -1))
        return x

    def _initialize_weights(self):
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                      cell.weight.shape,
                                                      ms.float32))
                if cell.bias is not None:
                    fan_in, _ = _calculate_in_and_out(cell.weight)
                    bound = 1 / math.sqrt(fan_in)
                    cell.bias.set_data(init.initializer(init.Uniform(bound),
                                                        cell.bias.shape,
                                                        ms.float32))


class WithLossCell(nn.Cell):
    """
    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label, classes):
        out = self._backbone(data)
        return self._loss_fn(out, label, classes)

@@ -0,0 +1,124 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet train script.
"""
import os
import datetime
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore import dataset as ds
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from src.EvalCallBack import EvalCallBack
from src.PrototypicalLoss import PrototypicalLoss
from src.parser_util import get_parser
from src.protonet import ProtoNet, WithLossCell
from model_init import init_dataloader

local_data_url = './cache/data'
local_train_url = './cache/out'

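# Training flow: wrap ProtoNet and the prototypical loss in WithLossCell, let
# Model drive the episodic GeneratorDataset, and attach EvalCallBack so the
# best checkpoint is selected on the validation episodes.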
def train(opt, tr_dataloader, net, loss_fn, eval_loss_fn, optim, path, val_dataloader=None):
    '''
    train function
    '''
    inp = ds.GeneratorDataset(tr_dataloader, column_names=['data', 'label', 'classes'])
    my_loss_cell = WithLossCell(net, loss_fn)
    my_acc_cell = WithLossCell(net, eval_loss_fn)
    model = Model(my_loss_cell, optimizer=optim)

    eval_data = ds.GeneratorDataset(val_dataloader, column_names=['data', 'label', 'classes'])

    eval_cb = EvalCallBack(opt, my_acc_cell, eval_data, path)
    config = CheckpointConfig(save_checkpoint_steps=10,
                              keep_checkpoint_max=5,
                              saved_network=net)
    ckpoint_cb = ModelCheckpoint(prefix='protonet', directory=path, config=config)

    print('==========start training==========')
    starttime = datetime.datetime.now()
    model.train(opt.epochs, inp, callbacks=[ckpoint_cb, eval_cb, TimeMonitor()])
    endtime = datetime.datetime.now()
    elapsed = (endtime - starttime).seconds
    print('epoch time:', elapsed / opt.epochs,
          'per step time:', elapsed / (opt.epochs * opt.iterations))


def main():
    '''
    main function
    '''
    global local_data_url
    global local_train_url

    options = get_parser().parse_args()

    if options.run_offline:
        device_num = int(os.environ.get("DEVICE_NUM", 1))
        if device_num > 1:
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
        context.set_context(device_id=options.device_id)
        local_data_url = options.dataset_root
        local_train_url = options.experiment_root
        if not os.path.exists(options.experiment_root):
            os.makedirs(options.experiment_root)
    else:
        device_num = int(os.environ.get("DEVICE_NUM", 1))
        device_id = int(os.getenv("DEVICE_ID"))

        import moxing as mox
        if not os.path.exists(local_train_url):
            os.makedirs(local_train_url)

        context.set_context(device_id=device_id)

        if device_num > 1:
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            local_data_url = os.path.join(local_data_url, str(device_id))
            local_train_url = os.path.join(local_train_url, str(device_id))

        mox.file.copy_parallel(src_url=options.data_url, dst_url=local_data_url)

    tr_dataloader = init_dataloader(options, 'train', local_data_url)
    val_dataloader = init_dataloader(options, 'val', local_data_url)

    loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_tr)
    eval_loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_val,
                                    is_train=False)

    Net = ProtoNet()
    optim = nn.Adam(params=Net.trainable_params(), learning_rate=options.learning_rate)
    train(options, tr_dataloader, Net, loss_fn, eval_loss_fn, optim, local_train_url, val_dataloader)
    if not options.run_offline:
        mox.file.copy_parallel(src_url='./cache/out', dst_url=options.train_url)


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE)
    main()