Merge pull request !15433 from qb/pro1
This commit is contained in:
i-robot 2021-06-29 09:37:48 +00:00 committed by Gitee
commit f939057247
17 changed files with 1397 additions and 0 deletions

@@ -0,0 +1,168 @@
# Contents
- [Prototypical-Network Description](#protonet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [ProtoNet Description](#contents)
MindSpore implementation of the NeurIPS 2017 paper: [Prototypical Networks for Few-shot Learning](https://arxiv.org/abs/1703.05175)
# [Model Architecture](#contents)
ProtoNet consists of an embedding encoder built from four identical convolution blocks, each a 3×3 convolution followed by batch normalization, ReLU, and 2×2 max-pooling. There are no additional classification layers: class prototypes are computed as the mean of the embedded support examples, and queries are classified by Euclidean distance to the nearest prototype.
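For reference, the encoder can be sketched as follows; this mirrors `conv_block` and `ProtoNet` in `src/protonet.py` (weight initialization omitted):

```python
import mindspore.nn as nn

def conv_block(in_channels, out_channels):
    """One encoder block: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool."""
    return nn.SequentialCell(
        nn.Conv2d(in_channels, out_channels, 3, pad_mode='pad', padding=1, has_bias=True),
        nn.BatchNorm2d(out_channels, momentum=0.1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2)
    )

# Four stacked blocks map a 1x28x28 Omniglot image to a 64-dimensional
# embedding (64 x 1 x 1 after four rounds of pooling, then flattened).
encoder = nn.SequentialCell(
    conv_block(1, 64),
    conv_block(64, 64),
    conv_block(64, 64),
    conv_block(64, 64)
)
```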
# [Dataset](#contents)
The scripts can be run with the dataset used in the original paper, which is also widely used in this domain. The following sections describe how to run them with the dataset below.
Dataset used: [omniglot](https://github.com/brendenlake/omniglot)
- Dataset size: 4.02 MB, 32,462 images of size 28×28 in 1,622 classes
- Train: 1,200 classes
- Test: 422 classes
- Data format: .png files
- Note: the data has been preprocessed into omniglot_resized
- The directory structure is as follows:
```text
└─Data
  ├─raw
  ├─splits
  │  └─vinyals
  │      test.txt
  │      train.txt
  │      val.txt
  │      trainval.txt
  └─data
      Alphabet_of_the_Magi
      Angelic
```
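Given this layout, the dataset class in `src/dataset.py` can be used directly. A minimal usage sketch (the root path is a placeholder):

```python
from src.dataset import OmniglotDataset

# mode selects a split file under splits/vinyals (train/val/test/trainval);
# each character class is expanded into four rotated variants (0/90/180/270 degrees).
dataset = OmniglotDataset(mode='train', root='/path/to/Data')
print(len(dataset))      # number of (image, label) items
img, label = dataset[0]  # 1x28x28 float32 array and an integer class index
```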
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```shell
# enter script dir, train ProtoNet in standalone mode
sh run_standalone_train_ascend.sh dataset 1 20 20
# enter script dir, train ProtoNet in distributed mode
sh run_distribution_ascend.sh rank_table dataset 20
# enter script dir, evaluate ProtoNet
sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```shell
├── cv
├── ProtoNet
├── requirements.txt
├── README.md // description of ProtoNet
├── scripts
│ ├──run_standalone_train_ascend.sh // standalone training on Ascend
│ ├──run_standalone_eval_ascend.sh // evaluation on Ascend
│ ├──run_distribution_ascend.sh // distributed training on Ascend
├── src
│ ├──parser_util.py // parameter configuration
│ ├──dataset.py // creating dataset
│ ├──IterDatasetGenerator.py // episodic batch generator
│ ├──protonet.py // ProtoNet architecture
│ ├──PrototypicalLoss.py // loss function
├── train.py // training script
├── eval.py // evaluation script
```
## [Script Parameters](#contents)
```python
Major parameters in train.py and parser_util.py are as follows:
--dataset_root: the absolute path to the train and evaluation datasets.
--experiment_root: the absolute path where checkpoints are stored after training.
--classes_per_it_tr: the number of random classes per training episode.
--num_support_tr: the number of support samples per class for training.
--num_query_tr: the number of query samples per class for training.
--classes_per_it_val: the number of random classes per validation episode.
--num_support_val: the number of support samples per class for validation.
--num_query_val: the number of query samples per class for validation.
--epochs: total training epochs.
--iterations: the number of episodes per epoch.
--learning_rate: learning rate.
--device_target: device where the code will run.
```
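To see how these options combine, each training episode draws `classes_per_it_tr` classes and `num_support_tr + num_query_tr` images per class. With the defaults from `src/parser_util.py`:

```python
# Episode composition under the defaults in src/parser_util.py.
classes_per_it_tr = 20   # ways: random classes per episode
num_support_tr = 5       # support shots per class
num_query_tr = 5         # query shots per class

episode_size = classes_per_it_tr * (num_support_tr + num_query_tr)
print(episode_size)      # 200 images fed to the network per episode
```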
## [Training Process](#contents)
### Training
```bash
# enter script dir, train ProtoNet in standalone mode
sh run_standalone_train_ascend.sh dataset 1 20 20
```
The model checkpoint will be saved in the current directory.
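Internally, `train.py` wraps the network and the prototypical loss in `WithLossCell` and trains with the standard MindSpore `Model` API. A condensed sketch of that flow (checkpoint and evaluation callbacks omitted; see `train.py` below for the full setup):

```python
import mindspore.nn as nn
from mindspore import dataset as ds
from mindspore.train import Model
from src.protonet import ProtoNet, WithLossCell
from src.PrototypicalLoss import PrototypicalLoss
from src.parser_util import get_parser
from model_init import init_dataloader

options = get_parser().parse_args()
net = ProtoNet()
loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr,
                           options.classes_per_it_tr)
loader = init_dataloader(options, 'train', options.dataset_root)
inp = ds.GeneratorDataset(loader, column_names=['data', 'label', 'classes'])
optim = nn.Adam(params=net.trainable_params(), learning_rate=options.learning_rate)
model = Model(WithLossCell(net, loss_fn), optimizer=optim)
model.train(options.epochs, inp)
```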
## [Evaluation Process](#contents)
### Evaluation
Before running the command below, please check the checkpoint path used for evaluation.
```bash
# enter script dir, evaluate ProtoNet
sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20
```
```text
Test Acc: 0.9954400658607483 Loss: 0.02102319709956646
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | ProtoNet |
| -------------------------- | ---------------------------------------------------------- |
| Resource | CentOS 8.2; Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Uploaded Date | 03/26/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | Omniglot |
| Training Parameters | episode=500, class_num=5, lr=0.001, classes_per_it_tr=60, num_support_tr=5, num_query_tr=5, classes_per_it_val=20, num_support_val=5, num_query_val=15 |
| Optimizer | Adam |
| Loss Function | Prototypical loss |
| Outputs | Accuracy |
| Loss | 0.002 |
| Speed | 215 ms/step |
| Total Time | 3h 23m (8 pcs) |
| Checkpoint for Fine-tuning | 440 KB (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/protonet |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet evaluation script.
"""
import os
import numpy as np
from mindspore import dataset as ds
from mindspore import load_checkpoint
import mindspore.context as context
from src.protonet import ProtoNet
from src.parser_util import get_parser
from src.PrototypicalLoss import PrototypicalLoss
from model_init import init_dataloader
from train import WithLossCell
def test(test_dataloader, net):
"""
test function
"""
inp = ds.GeneratorDataset(test_dataloader, column_names=['data', 'label', 'classes'])
avg_acc = list()
avg_loss = list()
for _ in range(10):
i = 0
for batch in inp.create_dict_iterator():
i = i + 1
print(i)
x = batch['data']
y = batch['label']
classes = batch['classes']
acc, loss = net(x, y, classes)
avg_acc.append(acc.asnumpy())
avg_loss.append(loss.asnumpy())
print('eval end')
avg_acc = np.mean(avg_acc)
avg_loss = np.mean(avg_loss)
print('Test Acc: {} Loss: {}'.format(avg_acc, avg_loss))
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE)
options = get_parser().parse_args()
if options.run_offline:
datapath = options.dataset_root
ckptpath = options.experiment_root
else:
import mox
mox.file.copy_parallel(src_url=options.data_url, dst_url='cache/data')
mox.file.copy_parallel(src_url=options.ckpt_url, dst_url='cache/ckpt')
datapath = 'cache/data'
ckptpath = 'cache/ckpt'
Net = ProtoNet()
loss_fn = PrototypicalLoss(options.num_support_val, options.num_query_val,
options.classes_per_it_val, is_train=False)
Net = WithLossCell(Net, loss_fn)
val_dataloader = init_dataloader(options, 'val', datapath)
load_checkpoint(os.path.join(ckptpath, 'best_ck.ckpt'), net=Net)
test(val_dataloader, Net)

@@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import numpy as np
import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.protonet import ProtoNet as protonet
parser = argparse.ArgumentParser(description='MindSpore ProtoNet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="protonet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
# define fusion network
network = protonet()
# load network checkpoint
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
# export network
inputs = Tensor(np.ones([args.batch_size, 1, 28, 28]), mindspore.float32)
export(network, inputs, file_name=args.file_name, file_format=args.file_format)

@@ -0,0 +1,63 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet model init script.
"""
import itertools
import mindspore.nn as nn
import numpy as np
from src.dataset import OmniglotDataset
from src.IterDatasetGenerator import IterDatasetGenerator
def init_lr_scheduler(opt):
'''
Initialize the learning rate scheduler
'''
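    # Piecewise-constant schedule: milestones at epochs 1, 1+step, 1+2*step, ...,
    # with the learning rate scaled by lr_scheduler_gamma on each successive segment.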
epochs = opt.epochs
milestone = list(itertools.takewhile(lambda n: n < epochs, itertools.count(1, opt.lr_scheduler_step)))
lr0 = opt.learning_rate
bl = list(np.logspace(0, len(milestone)-1, len(milestone), base=opt.lr_scheduler_gamma))
lr = [lr0*b for b in bl]
lr_epoch = nn.piecewise_constant_lr(milestone, lr)
return lr_epoch
def init_dataset(opt, mode, path):
'''
Initialize the dataset
'''
dataset = OmniglotDataset(mode=mode, root=path)
n_classes = len(np.unique(dataset.y))
if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:
        raise ValueError('There are not enough classes in the dataset to '
                         'satisfy the chosen classes_per_it. Decrease the '
                         'classes_per_it_{tr/val} option and try again.')
return dataset
def init_dataloader(opt, mode, path):
'''
Initialize the dataloader
'''
dataset = init_dataset(opt, mode, path)
if 'train' in mode:
classes_per_it = opt.classes_per_it_tr
num_samples = opt.num_support_tr + opt.num_query_tr
else:
classes_per_it = opt.classes_per_it_val
num_samples = opt.num_support_val + opt.num_query_val
dataloader = IterDatasetGenerator(dataset, classes_per_it, num_samples, opt.iterations)
return dataloader

@@ -0,0 +1,3 @@
numpy >= 1.17.0
tqdm >= 4.61.0
pillow >= 8.2.0

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple usage example follows; more parameters can be set.
if [ $# != 3 ]
then
echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
export TRAIN_CLASS=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
python train.py --data_path=$DATA_PATH \
--device_id=$DEVICE_ID --device_target="Ascend" \
--classes_per_it_tr=$TRAIN_CLASS > log 2>&1 &
cd ..
done

@@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple usage example follows; more parameters can be set.
if [ $# != 4 ]
then
echo "Usage: sh run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [EVAL_CLASS]"
exit 1
fi
export DATA_PATH=$1
export CKPT_PATH=$2
export DEVICE_ID=$3
export EVAL_CLASS=$4
python ../eval.py --dataset_root=$DATA_PATH --experiment_root=$CKPT_PATH \
--device_id=$DEVICE_ID --device_target="Ascend" \
--classes_per_it_val=$EVAL_CLASS > eval_log 2>&1 &

@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple usage example follows; more parameters can be set.
if [ $# != 4 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [DATA_PATH] [DEVICE_ID] [TRAIN_CLASS] [EPOCHS]"
exit 1
fi
export DATA_PATH=$1
export DEVICE_ID=$2
export TRAIN_CLASS=$3
export EPOCHS=$4
python ../train.py --dataset_root=$DATA_PATH \
--device_id=$DEVICE_ID --device_target="Ascend" \
--classes_per_it_tr=$TRAIN_CLASS \
--epochs=$EPOCHS > log 2>&1 &

@@ -0,0 +1,88 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Callback for eval
"""
import os
from mindspore.train.callback import Callback
from mindspore import save_checkpoint
import numpy as np
class EvalCallBack(Callback):
"""
CallBack class
"""
def __init__(self, options, net, eval_dataset, path):
self.net = net
self.eval_dataset = eval_dataset
self.path = path
self.avgacc = 0
self.avgloss = 0
self.bestacc = 0
self.options = options
def epoch_begin(self, run_context):
"""
CallBack epoch begin
"""
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
print('=========EPOCH {} BEGIN========='.format(cur_epoch))
def epoch_end(self, run_context):
"""
CallBack epoch end
"""
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
cur_net = cb_param.network
evalnet = self.net
self.avgacc, self.avgloss = self.eval(self.eval_dataset, evalnet)
if self.avgacc > self.bestacc:
self.bestacc = self.avgacc
print('Epoch {}: Avg Accuracy: {}(best) Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
best_path = os.path.join(self.path, 'best_ck.ckpt')
save_checkpoint(cur_net, best_path)
else:
print('Epoch {}: Avg Accuracy: {} Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
last_path = os.path.join(self.path, 'last_ck.ckpt')
save_checkpoint(cur_net, last_path)
print("Best Acc:", self.bestacc)
print('=========EPOCH {} END========='.format(cur_epoch))
def eval(self, inp, net):
"""
CallBack eval
"""
avg_acc = list()
avg_loss = list()
for _ in range(10):
for batch in inp.create_dict_iterator():
x = batch['data']
y = batch['label']
classes = batch['classes']
acc, loss = net(x, y, classes)
avg_acc.append(acc.asnumpy())
avg_loss.append(loss.asnumpy())
avg_acc = np.mean(avg_acc)
avg_loss = np.mean(avg_loss)
return avg_acc, avg_loss

@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset iter generator script.
"""
import numpy as np
from tqdm import tqdm
class IterDatasetGenerator:
"""
dataloader class
"""
def __init__(self, data, classes_per_it, num_samples, iterations):
self.__iterations = iterations
self.__data = data.x
self.__labels = data.y
self.__iter = 0
self.classes_per_it = classes_per_it
self.sample_per_class = num_samples
self.classes, self.counts = np.unique(self.__labels, return_counts=True)
self.idxs = range(len(self.__labels))
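        # Per-class index matrix: row i holds the dataset indices of class i,
        # padded with NaN up to the size of the largest class.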
self.indexes = np.empty((len(self.classes), max(self.counts)), dtype=int) * np.nan
self.numel_per_class = np.zeros_like(self.classes)
for idx, label in tqdm(enumerate(self.__labels)):
label_idx = np.argwhere(self.classes == label).item()
self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
self.numel_per_class[label_idx] = int(self.numel_per_class[label_idx]) + 1
print('init end')
def __next__(self):
spc = self.sample_per_class
cpi = self.classes_per_it
if self.__iter >= self.__iterations:
raise StopIteration
batch_size = spc * cpi
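        # Placeholder batch buffer; every slot is overwritten in the loop below,
        # so the random initial values are never used.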
batch = np.random.randint(low=batch_size, high=10 * batch_size, size=(batch_size), dtype=np.int32)
c_idxs = np.random.permutation(len(self.classes))[:cpi]
for indx, c in enumerate(self.classes[c_idxs]):
index = indx*spc
            label_idx = [c_i for c_i in range(len(self.classes)) if self.classes[c_i] == c][0]
sample_idxs = np.random.permutation(int(self.numel_per_class[label_idx]))[:spc]
ind = 0
for sid in sample_idxs:
batch[index+ind] = self.indexes[label_idx][sid]
ind = ind + 1
batch = batch[np.random.permutation(len(batch))]
data_x = []
data_y = []
for b in batch:
data_x.append(self.__data[b])
data_y.append(self.__labels[b])
self.__iter += 1
data_y = np.asarray(data_y, np.int32)
data_class = np.asarray(np.unique(data_y), np.int32)
item = (data_x, data_y, data_class)
return item
def __iter__(self):
self.__iter = 0
return self
def __len__(self):
return self.__iterations

@@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function script.
"""
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.loss.loss import _Loss
import mindspore as ms
import numpy as np
class PrototypicalLoss(_Loss):
'''
Loss class deriving from Module for the prototypical loss function defined below
'''
def __init__(self, n_support, n_query, n_class, is_train=True):
super(PrototypicalLoss, self).__init__()
self.n_support = n_support
self.n_query = n_query
self.eq = ops.Equal()
self.log_softmax = nn.LogSoftmax(1)
self.gather = ops.GatherD()
self.squeeze = ops.Squeeze()
self.max = ops.Argmax(2)
self.cast = ops.Cast()
self.stack = ops.Stack()
self.reshape = ops.Reshape()
self.topk = ops.TopK(sorted=True)
self.expendDims = ops.ExpandDims()
self.broadcastTo = ops.BroadcastTo((100, 20, 64))
self.pow = ops.Pow()
self.sum = ops.ReduceSum()
self.zeros = Tensor(np.zeros(200), ms.float32)
self.ones = Tensor(np.ones(200), ms.float32)
self.print = ops.Print()
self.unique = ops.Unique()
self.samples_count = 10
self.select = ops.Select()
self.target_inds = Tensor(list(range(0, n_class)), ms.int32)
self.is_train = is_train
        self.acc_val = 0
def construct(self, inp, target, classes):
"""
loss construct
"""
n_classes = len(classes)
n_query = self.n_query
support_idxs = ()
query_idxs = ()
for ind, _ in enumerate(classes):
class_c = classes[ind]
_, a = self.topk(self.cast(self.eq(target, class_c), ms.float32), self.n_support + self.n_query)
support_idx = self.squeeze(a[:self.n_support])
support_idxs += (support_idx,)
query_idx = a[self.n_support:]
query_idxs += (query_idx,)
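        # Prototype of each class = mean of its embedded support samples.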
prototypes = ()
for idx_list in support_idxs:
prototypes += (inp[idx_list].mean(0),)
prototypes = self.stack(prototypes)
query_idxs = self.stack(query_idxs).view(-1)
query_samples = inp[query_idxs]
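        # Squared Euclidean distance from each query embedding to each prototype;
        # log-softmax over the negated distances gives per-class log-probabilities.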
dists = euclidean_dist(query_samples, prototypes)
log_p_y = self.log_softmax(-dists)
log_p_y = self.reshape(log_p_y, (n_classes, n_query, -1))
target_inds = self.target_inds.view(n_classes, 1, 1)
target_inds = ops.BroadcastTo((n_classes, n_query, 1))(target_inds) # to int64
loss_val = -self.squeeze(self.gather(log_p_y, 2, target_inds)).view(-1).mean()
y_hat = self.max(log_p_y)
acc_val = self.cast(self.eq(y_hat, self.squeeze(target_inds)), ms.float32).mean()
if self.is_train:
return loss_val
return acc_val, loss_val
def supp_idxs(self, target, c):
return self.squeeze(self.nonZero(self.eq(target, c))[:self.n_support])
    def nonZero(self, inpbool):
        # Collect the indices (not the values) of the True entries.
        out = []
        for i, inp in enumerate(inpbool):
            if inp:
                out.append(i)
        return Tensor(out, ms.int32)
def acc(self):
return self.acc_val
def euclidean_dist(x, y):
'''
Compute euclidean distance between two tensors
'''
# x: N x D
# y: M x D
n = x.shape[0]
m = y.shape[0]
d = x.shape[1]
expendDims = ops.ExpandDims()
broadcastTo = ops.BroadcastTo((n, m, d))
pow_op = ops.Pow()
reducesum = ops.ReduceSum()
x = broadcastTo(expendDims(x, 1))
y = broadcastTo(expendDims(y, 0))
return reducesum(pow_op(x-y, 2), 2)

@@ -0,0 +1,129 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset for ProtoNet
"""
import os
from PIL import Image
import numpy as np
IMG_CACHE = {}
class OmniglotDataset():
"""
Omniglot dataset class
"""
splits_folder = os.path.join('splits', 'vinyals')
raw_folder = 'raw'
processed_folder = 'data'
def __init__(self, mode='train', root='.' + os.sep + 'dataset', transform=None, target_transform=None):
self.root = root
print(self.root)
self.transform = transform
self.target_transform = target_transform
self.classes = get_current_classes(os.path.join(
self.root, self.splits_folder, mode + '.txt'))
self.all_items = find_items(os.path.join(
self.root, self.processed_folder), self.classes)
self.idx_classes = index_classes(self.all_items)
paths, self.y = zip(*[self.get_path_label(pl)
for pl in range(len(self))])
self.x = map(load_img, paths, range(len(paths)))
self.x = list(self.x)
def __getitem__(self, idx):
x = self.x[idx]
if self.transform:
x = self.transform(x)
return x, self.y[idx]
def __len__(self):
return len(self.all_items)
def get_path_label(self, index):
filename = self.all_items[index][0]
rot = self.all_items[index][-1]
img = str.join(os.sep, [self.all_items[index][2], filename]) + rot
target = self.idx_classes[self.all_items[index]
[1] + self.all_items[index][-1]]
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def find_items(root_dir, classes):
"""
function to find items
"""
retour = []
rots = [os.sep + 'rot000', os.sep + 'rot090', os.sep + 'rot180', os.sep + 'rot270']
for (root, _, files) in os.walk(root_dir):
for f in files:
r = root.split(os.sep)
lr = len(r)
label = r[lr - 2] + os.sep + r[lr - 1]
for rot in rots:
if label + rot in classes and (f.endswith("png")):
retour.extend([(f, label, root, rot)])
print("== Dataset: Found %d items " % len(retour))
return retour
def index_classes(items):
"""
    Map each (class, rotation) pair to a numeric class index.
"""
idx = {}
for i in items:
if not i[1] + i[-1] in idx:
idx[i[1] + i[-1]] = len(idx)
print("== Dataset: Found %d classes" % len(idx))
return idx
def get_current_classes(fname):
"""
get current classes
"""
with open(fname) as f:
classes = f.read().replace('/', os.sep).splitlines()
return classes
def load_img(path, idx):
"""
function to load images
"""
path, rot = path.split(os.sep + 'rot')
if path in IMG_CACHE:
x = IMG_CACHE[path]
else:
x = Image.open(path)
IMG_CACHE[path] = x
x = x.rotate(float(rot))
x = x.resize((28, 28))
shape = 1, x.size[0], x.size[1]
x = np.array(x, np.float32, copy=False)
x = 1.0 - x
x = x.T.reshape(shape)
return x

@@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet parser_util script.
"""
import os
import argparse
def get_parser():
"""
ProtoNet parser_util script.
"""
parser = argparse.ArgumentParser()
    parser.add_argument('--run_offline', default=True, help='whether to run offline (True) or on the cloud (False)')
parser.add_argument('--data_url', default=None, help='Location of data.')
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
parser.add_argument('--ckpt_url', default=None, help='Location of training outputs.')
parser.add_argument('-root', '--dataset_root',
type=str,
help='path to dataset',
default='..' + os.sep + 'dataset')
    parser.add_argument('-target', '--device_target',
                        type=str,
                        help='device target, default=Ascend',
                        default='Ascend')
parser.add_argument('-id', '--device_id',
type=int,
                        help='device id',
default=0)
parser.add_argument('-exp', '--experiment_root',
type=str,
help='root where to store models, losses and accuracies',
default='..' + os.sep + 'output')
parser.add_argument('-nep', '--epochs',
type=int,
help='number of epochs to train for',
default=2)
parser.add_argument('-lr', '--learning_rate',
type=float,
help='learning rate for the model, default=0.001',
default=0.001)
parser.add_argument('-lrS', '--lr_scheduler_step',
type=int,
help='StepLR learning rate scheduler step, default=20',
default=20)
parser.add_argument('-lrG', '--lr_scheduler_gamma',
type=float,
help='StepLR learning rate scheduler gamma, default=0.5',
default=0.5)
parser.add_argument('-its', '--iterations',
type=int,
help='number of episodes per epoch, default=100',
default=100)
parser.add_argument('-cTr', '--classes_per_it_tr',
type=int,
                        help='number of random classes per episode for training, default=20',
default=20)
parser.add_argument('-nsTr', '--num_support_tr',
type=int,
help='number of samples per class to use as support for training, default=5',
default=5)
parser.add_argument('-nqTr', '--num_query_tr',
type=int,
help='number of samples per class to use as query for training, default=5',
default=5)
parser.add_argument('-cVa', '--classes_per_it_val',
type=int,
                        help='number of random classes per episode for validation, default=20',
default=20)
parser.add_argument('-nsVa', '--num_support_val',
type=int,
help='number of samples per class to use as support for validation, default=5',
default=5)
parser.add_argument('-nqVa', '--num_query_val',
type=int,
help='number of samples per class to use as query for validation, default=15',
default=15)
parser.add_argument('-seed', '--manual_seed',
type=int,
help='input for the manual seeds initializations',
default=7)
parser.add_argument('--cuda',
action='store_true',
help='enables cuda')
return parser

@@ -0,0 +1,257 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet.
"""
from functools import reduce
import math
import numpy as np
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops
from mindspore.common import initializer as init
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
        >>> gain = _calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
    elif not isinstance(param, bool) and isinstance(param, (int, float)):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
        raise ValueError("If initializing data with xavier uniform, the dimension of data must be greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
"""
select fan
"""
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
def _initialize(self, arr):
pass
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
        >>> w = np.empty((3, 5))
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
        >>> w = np.empty((3, 5))
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def conv_block(in_channels, out_channels):
'''
returns a block conv-bn-relu-pool
'''
return nn.SequentialCell(
nn.Conv2d(in_channels, out_channels, 3, pad_mode='pad', padding=1, has_bias=True),
nn.BatchNorm2d(out_channels, momentum=0.1),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
class ProtoNet(nn.Cell):
'''
Model as described in the reference paper,
source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
'''
def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
super(ProtoNet, self).__init__()
self.encoder = nn.SequentialCell(
conv_block(x_dim, hid_dim),
conv_block(hid_dim, hid_dim),
conv_block(hid_dim, hid_dim),
conv_block(hid_dim, z_dim),
)
self._initialize_weights()
self.print = ops.Print()
def construct(self, x):
x = self.encoder(x)
reshape = ops.Reshape()
x = reshape(x, (x.shape[0], -1))
return x
def _initialize_weights(self):
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
ms.float32))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
ms.float32))
class WithLossCell(nn.Cell):
"""
    Examples:
        >>> net = ProtoNet()
        >>> loss_fn = PrototypicalLoss(n_support, n_query, n_class)
        >>> net_with_criterion = WithLossCell(net, loss_fn)
        >>>
        >>> # data, label and classes are the three columns produced by
        >>> # IterDatasetGenerator for one episode
        >>> output = net_with_criterion(data, label, classes)
    """
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, data, label, classes):
out = self._backbone(data)
return self._loss_fn(out, label, classes)

@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet train script.
"""
import os
import datetime
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore import dataset as ds
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from src.EvalCallBack import EvalCallBack
from src.protonet import WithLossCell
from src.PrototypicalLoss import PrototypicalLoss
from src.parser_util import get_parser
from src.protonet import ProtoNet
from model_init import init_dataloader
local_data_url = './cache/data'
local_train_url = './cache/out'
def train(opt, tr_dataloader, net, loss_fn, eval_loss_fn, optim, path, val_dataloader=None):
'''
train function
'''
inp = ds.GeneratorDataset(tr_dataloader, column_names=['data', 'label', 'classes'])
my_loss_cell = WithLossCell(net, loss_fn)
my_acc_cell = WithLossCell(net, eval_loss_fn)
model = Model(my_loss_cell, optimizer=optim)
eval_data = ds.GeneratorDataset(val_dataloader, column_names=['data', 'label', 'classes'])
eval_cb = EvalCallBack(opt, my_acc_cell, eval_data, path)
config = CheckpointConfig(save_checkpoint_steps=10,
keep_checkpoint_max=5,
saved_network=net)
ckpoint_cb = ModelCheckpoint(prefix='protonet', directory=path, config=config)
print('==========training test==========')
starttime = datetime.datetime.now()
model.train(opt.epochs, inp, callbacks=[ckpoint_cb, eval_cb, TimeMonitor()])
endtime = datetime.datetime.now()
    print('epoch time:', (endtime - starttime).seconds / opt.epochs,
          'per step time:', (endtime - starttime).seconds / (opt.epochs * opt.iterations))
def main():
'''
main function
'''
global local_data_url
global local_train_url
options = get_parser().parse_args()
if options.run_offline:
device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1:
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
context.set_context(device_id=options.device_id)
local_data_url = options.dataset_root
local_train_url = options.experiment_root
if not os.path.exists(options.experiment_root):
os.makedirs(options.experiment_root)
else:
device_num = int(os.environ.get("DEVICE_NUM", 1))
device_id = int(os.getenv("DEVICE_ID"))
import moxing as mox
if not os.path.exists(local_train_url):
os.makedirs(local_train_url)
context.set_context(device_id=device_id)
if device_num > 1:
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
local_data_url = os.path.join(local_data_url, str(device_id))
local_train_url = os.path.join(local_train_url, str(device_id))
mox.file.copy_parallel(src_url=options.data_url, dst_url=local_data_url)
tr_dataloader = init_dataloader(options, 'train', local_data_url)
val_dataloader = init_dataloader(options, 'val', local_data_url)
loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_tr)
eval_loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_val,
is_train=False)
Net = ProtoNet()
    optim = nn.Adam(params=Net.trainable_params(), learning_rate=options.learning_rate)
train(options, tr_dataloader, Net, loss_fn, eval_loss_fn, optim, local_train_url, val_dataloader)
if not options.run_offline:
mox.file.copy_parallel(src_url='./cache/out', dst_url=options.train_url)
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE)
main()