From c591801473a07b389f1523e4d16c0ee39d3e2ded Mon Sep 17 00:00:00 2001 From: jhzj Date: Tue, 20 Apr 2021 20:17:04 +0800 Subject: [PATCH] add protonet1 --- model_zoo/research/cv/ProtoNet/README.md | 168 ++++++++++++ model_zoo/research/cv/ProtoNet/__init__.py | 0 model_zoo/research/cv/ProtoNet/eval.py | 71 +++++ model_zoo/research/cv/ProtoNet/export.py | 49 ++++ model_zoo/research/cv/ProtoNet/model_init.py | 63 +++++ .../research/cv/ProtoNet/requirements.txt | 3 + .../scripts/run_distribution_ascend.sh | 55 ++++ .../scripts/run_standalone_eval_ascend.sh | 30 ++ .../scripts/run_standalone_train_ascend.sh | 31 +++ .../research/cv/ProtoNet/src/EvalCallBack.py | 88 ++++++ .../cv/ProtoNet/src/IterDatasetGenerator.py | 80 ++++++ .../cv/ProtoNet/src/PrototypicalLoss.py | 131 +++++++++ .../research/cv/ProtoNet/src/__init__.py | 0 model_zoo/research/cv/ProtoNet/src/dataset.py | 129 +++++++++ .../research/cv/ProtoNet/src/parser_util.py | 118 ++++++++ .../research/cv/ProtoNet/src/protonet.py | 257 ++++++++++++++++++ model_zoo/research/cv/ProtoNet/train.py | 124 +++++++++ 17 files changed, 1397 insertions(+) create mode 100644 model_zoo/research/cv/ProtoNet/README.md create mode 100644 model_zoo/research/cv/ProtoNet/__init__.py create mode 100644 model_zoo/research/cv/ProtoNet/eval.py create mode 100644 model_zoo/research/cv/ProtoNet/export.py create mode 100644 model_zoo/research/cv/ProtoNet/model_init.py create mode 100644 model_zoo/research/cv/ProtoNet/requirements.txt create mode 100644 model_zoo/research/cv/ProtoNet/scripts/run_distribution_ascend.sh create mode 100644 model_zoo/research/cv/ProtoNet/scripts/run_standalone_eval_ascend.sh create mode 100644 model_zoo/research/cv/ProtoNet/scripts/run_standalone_train_ascend.sh create mode 100644 model_zoo/research/cv/ProtoNet/src/EvalCallBack.py create mode 100644 model_zoo/research/cv/ProtoNet/src/IterDatasetGenerator.py create mode 100644 model_zoo/research/cv/ProtoNet/src/PrototypicalLoss.py create mode 100644 
model_zoo/research/cv/ProtoNet/src/__init__.py create mode 100644 model_zoo/research/cv/ProtoNet/src/dataset.py create mode 100644 model_zoo/research/cv/ProtoNet/src/parser_util.py create mode 100644 model_zoo/research/cv/ProtoNet/src/protonet.py create mode 100644 model_zoo/research/cv/ProtoNet/train.py diff --git a/model_zoo/research/cv/ProtoNet/README.md b/model_zoo/research/cv/ProtoNet/README.md new file mode 100644 index 00000000000..e55afcabf93 --- /dev/null +++ b/model_zoo/research/cv/ProtoNet/README.md @@ -0,0 +1,168 @@ +# Contents + +- [Prototypical-Network Description](#protonet-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Environment Requirements](#environment-requirements) +- [Quick Start](#quick-start) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Script Parameters](#script-parameters) + - [Training Process](#training-process) + - [Training](#training) + - [Evaluation Process](#evaluation-process) + - [Evaluation](#evaluation) +- [Model Description](#model-description) + - [Performance](#performance) + - [Evaluation Performance](#evaluation-performance) +- [ModelZoo Homepage](#modelzoo-homepage) + +# [protonet-Description](#contents) + +MindSpore implementation of the NeurIPS 2017 paper: [Prototypical Networks for Few-shot Learning](https://arxiv.org/abs/1703.05175) + +# [Model Architecture](#contents) + +Proto-Net contains 2 parts named Encoder and Relation. The former one has 4 convolution layers, the latter one has 2 convolution layers and 2 linear layers. + +# [Dataset](#contents) + +Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below. 
+ +Dataset used: [omniglot](https://github.com/brendenlake/omniglot) + +- Dataset size 4.02M, 32462 28*28 images in 1622 classes + - Train 1,200 classes + - Test 422 classes +- Data format .png files + - Note Data has been processed in omniglot_resized + +- The directory structure is as follows: + +```text +└─Data + ├─raw + ├─splits + │ vinyals + │ test.txt + │ train.txt + │ val.txt + │ trainval.txt + └─data + Alphabet_of_the_Magi + Angelic +``` + +# [Environment Requirements](#contents) + +- Hardware(Ascend) + - Prepare hardware environment with Ascend. +- Framework + - [MindSpore](https://www.mindspore.cn/install/en) +- For more information, please check the resources below: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [Quick Start](#contents) + +After installing MindSpore via the official website, you can start training and evaluation as follows: + +```shell +# enter script dir, train ProtoNet in standalone +sh run_standalone_train_ascend.sh dataset 1 20 20 +# enter script dir, train ProtoNet in distribution +sh run_distribution_ascend.sh dataset rank_table dataset 20 +# enter script dir, evaluate ProtoNet +sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20 +``` + +## [Script and Sample Code](#contents) + +```shell +├── cv + ├── ProtoNet + ├── requirements.txt + ├── README.md // descriptions about ProtoNet + ├── scripts + │ ├──run_standalone_train_ascend.sh // train in ascend + │ ├──run_standalone_eval_ascend.sh // evaluate in ascend + │ ├──run_distribution_ascend.sh // distribution in ascend + ├── src + │ ├──parser_util.py // parameter configuration + │ ├──dataset.py // creating dataset + │ ├──IterDatasetGenerator.py // generate dataset + │ ├──protonet.py // ProtoNet architecture + │ ├──PrototypicalLoss.py // loss function + ├── train.py // training script + ├── eval.py // evaluation script +``` + +## [Script 
Parameters](#contents) + +```python +Major parameters in train.py and src/parser_util.py as follows: + +--class_num: the number of class we use in one step. +--sample_num_per_class: the number of query data we extract from one class. +--batch_num_per_class: the number of support data we extract from one class. +--data_path: The absolute full path to the train and evaluation datasets. +--episode: Total training epochs. +--test_episode: Total testing episodes +--learning_rate: Learning rate +--device_target: Device where the code will be implemented. +--save_dir: The absolute full path to the checkpoint file saved + after training. +--data_path: Path where the dataset is saved +``` + +## [Training Process](#contents) + +### Training + +```bash +# enter script dir, train ProtoNet in standalone +sh run_standalone_train_ascend.sh dataset 1 20 20 +``` + +The model checkpoint will be saved in the current directory. + +## [Evaluation Process](#contents) + +### Evaluation + +Before running the command below, please check the checkpoint path used for evaluation. 
+ +```bash +# enter script dir, evaluate ProtoNet +sh run_standalone_eval_ascend.sh dataset best.ckpt 1 20 +``` + +```text +Test Acc: 0.9954400658607483 Loss: 0.02102319709956646 +``` + +# [Model Description](#contents) + +## [Performance](#contents) + +### Evaluation Performance + +| Parameters | ProtoNet | +| -------------------------- | ---------------------------------------------------------- | +| Resource | CentOS 8.2; Ascend 910; CPU 2.60GHz; 192cores; Memory 755G | +| uploaded Date | 03/26/2021 (month/day/year) | +| MindSpore Version | 1.2.0 | +| Dataset | OMNIGLOT | +| Training Parameters | episode=500, class_num = 5, lr=0.001, classes_per_it_tr=60, num_support_tr=5, num_query_tr=5, classes_per_it_val=20, num_support_val=5, num_query_val=15 | +| Optimizer | Adam | +| Loss Function | PrototypicalLoss | +| outputs | Accuracy | +| Loss | 0.002 | +| Speed | 215 ms/step | +| Total time | 3 h 23m (8p) | +| Checkpoint for Fine tuning | 440 KB (.ckpt file) | +| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ProtoNet | + +# [ModelZoo Homepage](#contents) + + Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/research/cv/ProtoNet/__init__.py b/model_zoo/research/cv/ProtoNet/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/research/cv/ProtoNet/eval.py b/model_zoo/research/cv/ProtoNet/eval.py new file mode 100644 index 00000000000..27d7cf3daa7 --- /dev/null +++ b/model_zoo/research/cv/ProtoNet/eval.py @@ -0,0 +1,71 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def test(test_dataloader, net, n_rounds=10):
    """
    Evaluate ``net`` on the episodes produced by ``test_dataloader``.

    Args:
        test_dataloader: generator source handed to ``ds.GeneratorDataset``;
            it must yield ``(data, label, classes)`` items.
        net: callable invoked as ``net(data, label, classes)`` that returns
            an ``(accuracy, loss)`` pair of tensors.
        n_rounds (int): how many full passes over the dataloader are
            averaged. Defaults to 10, the value that used to be hard-coded.

    Returns:
        tuple: ``(mean accuracy, mean loss)`` over all evaluated episodes.
    """
    inp = ds.GeneratorDataset(test_dataloader, column_names=['data', 'label', 'classes'])
    avg_acc = []
    avg_loss = []
    for _ in range(n_rounds):
        for batch in inp.create_dict_iterator():
            acc, loss = net(batch['data'], batch['label'], batch['classes'])
            avg_acc.append(acc.asnumpy())
            avg_loss.append(loss.asnumpy())
    print('eval end')
    avg_acc = np.mean(avg_acc)
    avg_loss = np.mean(avg_loss)
    print('Test Acc: {} Loss: {}'.format(avg_acc, avg_loss))
    return avg_acc, avg_loss


if __name__ == '__main__':
    options = get_parser().parse_args()
    # Honour the --device_target / --device_id options that the launch script
    # passes. An exported DEVICE_ID still wins, so environment-driven launches
    # keep selecting the same device as before.
    context.set_context(mode=context.GRAPH_MODE, device_target=options.device_target)
    if options.device_target == 'Ascend':
        context.set_context(device_id=int(os.getenv('DEVICE_ID', str(options.device_id))))

    if options.run_offline:
        datapath = options.dataset_root
        ckptpath = options.experiment_root
    else:
        # Cloud run: stage the dataset and checkpoint into local cache folders.
        import mox
        mox.file.copy_parallel(src_url=options.data_url, dst_url='cache/data')
        mox.file.copy_parallel(src_url=options.ckpt_url, dst_url='cache/ckpt')
        datapath = 'cache/data'
        ckptpath = 'cache/ckpt'

    loss_fn = PrototypicalLoss(options.num_support_val, options.num_query_val,
                               options.classes_per_it_val, is_train=False)
    eval_net = WithLossCell(ProtoNet(), loss_fn)
    val_dataloader = init_dataloader(options, 'val', datapath)
    # The checkpoint file name matches the one written by EvalCallBack.
    load_checkpoint(os.path.join(ckptpath, 'best_ck.ckpt'), net=eval_net)
    test(val_dataloader, eval_net)
def parse_args(argv=None):
    """
    Parse the command-line options of the export script.

    Args:
        argv (list[str] | None): argument list to parse; ``None`` means
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: parsed options (device_id, batch_size, ckpt_file,
        file_name, file_format, device_target).
    """
    parser = argparse.ArgumentParser(description='MindSpore ProtoNet export')
    parser.add_argument("--device_id", type=int, default=0, help="Device id")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
    parser.add_argument("--file_name", type=str, default="protonet", help="output file name.")
    parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR",
                        help="file format")
    parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                        help="device target")
    return parser.parse_args(argv)


def main():
    """Load a ProtoNet checkpoint and export it in the requested format."""
    args = parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if args.device_target == "Ascend":
        context.set_context(device_id=args.device_id)

    # Build the network and restore the trained weights.
    network = protonet()
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, param_dict)

    # Trace the network with a dummy 1x28x28 input batch and write the model file.
    inputs = Tensor(np.ones([args.batch_size, 1, 28, 28]), mindspore.float32)
    export(network, inputs, file_name=args.file_name, file_format=args.file_format)


if __name__ == "__main__":
    main()
def init_lr_scheduler(opt):
    '''
    Build the piecewise-constant learning-rate list.

    Milestones are 1, 1 + lr_scheduler_step, ... (all strictly below
    ``opt.epochs``); the k-th segment uses
    ``opt.learning_rate * opt.lr_scheduler_gamma ** k``.

    Args:
        opt: parsed options providing ``epochs``, ``lr_scheduler_step``,
            ``learning_rate`` and ``lr_scheduler_gamma``.

    Returns:
        The value returned by ``nn.piecewise_constant_lr`` for those
        milestones and rates.
    '''
    epochs = opt.epochs
    milestone = list(itertools.takewhile(lambda n: n < epochs, itertools.count(1, opt.lr_scheduler_step)))
    lr = [float(opt.learning_rate * opt.lr_scheduler_gamma ** k) for k in range(len(milestone))]
    return nn.piecewise_constant_lr(milestone, lr)


def init_dataset(opt, mode, path):
    '''
    Load the Omniglot split ``mode`` found under ``path``.

    Only the ``classes_per_it`` option that applies to ``mode`` is checked
    (train options when 'train' is in ``mode``, validation options otherwise,
    mirroring ``init_dataloader``), so a small validation split is no longer
    rejected because of the training setting.

    Raises:
        ValueError: if the split holds fewer classes than one episode needs.
    '''
    dataset = OmniglotDataset(mode=mode, root=path)
    n_classes = len(np.unique(dataset.y))
    required = opt.classes_per_it_tr if 'train' in mode else opt.classes_per_it_val
    if n_classes < required:
        raise ValueError('There are not enough classes in the dataset in order ' +
                         'to satisfy the chosen classes_per_it. Decrease the ' +
                         'classes_per_it_{tr/val} option and try again.')
    return dataset


def init_dataloader(opt, mode, path):
    '''
    Create the episodic batch generator for the split ``mode``.

    Each episode draws ``classes_per_it`` classes with
    ``num_support + num_query`` samples per class; the train or validation
    variants of those options are picked depending on whether 'train' is in
    ``mode``. ``opt.iterations`` episodes are produced per pass.
    '''
    dataset = init_dataset(opt, mode, path)
    if 'train' in mode:
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    return IterDatasetGenerator(dataset, classes_per_it, num_samples, opt.iterations)
# A simple 8-device launch example; more train.py options can be appended below.
# A 4th argument is tolerated (and ignored) because older invocations passed one.
if [ $# -lt 3 ] || [ $# -gt 4 ]
then
    echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [TRAIN_CLASS]"
    exit 1
fi

if [ ! -f "$1" ]
then
    echo "error: RANK_TABLE_FILE=$1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath "$1")
export RANK_TABLE_FILE
# Resolve to an absolute path: every worker runs from its own sub-directory,
# so a relative dataset path would no longer point at the dataset.
DATA_PATH=$(realpath "$2")
export DATA_PATH
export TRAIN_CLASS=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

# Locate the project root from the script location so the launch works both
# from the scripts directory and from the project root.
PROJECT_DIR=$(cd "$(dirname "$0")/.." && pwd)

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
# POSIX while-loop instead of the bash-only for((...)) form, because the
# documented invocation is "sh run_distribution_ascend.sh".
i=0
while [ "$i" -lt "$DEVICE_NUM" ]
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf "./train_parallel$i"
    mkdir "./train_parallel$i"
    cp -r "$PROJECT_DIR/src" "./train_parallel$i"
    # Copy every top-level module, not only train.py. NOTE(review): train.py
    # presumably imports model_init.py like eval.py does - confirm.
    cp "$PROJECT_DIR"/*.py "./train_parallel$i"
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd "./train_parallel$i" || exit
    env > env.log
    # --dataset_root is the option the argument parser defines for the dataset path.
    python train.py --dataset_root="$DATA_PATH" \
        --device_id="$DEVICE_ID" --device_target="Ascend" \
        --classes_per_it_tr="$TRAIN_CLASS" > log 2>&1 &
    cd ..
    i=$((i + 1))
done
# A simple single-device evaluation example; more eval.py options can be appended below.
if [ $# != 4 ]
then
    echo "Usage: sh run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [EVAL_CLASS]"
    exit 1
fi

export DATA_PATH=$1
# CKPT_PATH is passed as --experiment_root: eval.py looks for best_ck.ckpt inside it.
export CKPT_PATH=$2
export DEVICE_ID=$3
export EVAL_CLASS=$4

# Resolve eval.py from the script location so the launch does not depend on
# the caller's working directory; quoting keeps paths with spaces intact.
PROJECT_DIR=$(cd "$(dirname "$0")/.." && pwd)

python "$PROJECT_DIR/eval.py" --dataset_root="$DATA_PATH" --experiment_root="$CKPT_PATH" \
    --device_id="$DEVICE_ID" --device_target="Ascend" \
    --classes_per_it_val="$EVAL_CLASS" > eval_log 2>&1 &
# A simple single-device training example; more train.py options can be appended below.
if [ $# != 4 ]
then
    echo "Usage: sh run_standalone_train_ascend.sh [DATA_PATH] [DEVICE_ID] [TRAIN_CLASS] [EPOCHS]"
    exit 1
fi

export DATA_PATH=$1
export DEVICE_ID=$2
export TRAIN_CLASS=$3
export EPOCHS=$4

# Resolve train.py from the script location so the launch does not depend on
# the caller's working directory; quoting keeps paths with spaces intact.
PROJECT_DIR=$(cd "$(dirname "$0")/.." && pwd)

python "$PROJECT_DIR/train.py" --dataset_root="$DATA_PATH" \
    --device_id="$DEVICE_ID" --device_target="Ascend" \
    --classes_per_it_tr="$TRAIN_CLASS" \
    --epochs="$EPOCHS" > log 2>&1 &
class EvalCallBack(Callback):
    """
    Training callback that evaluates after every epoch and saves checkpoints.

    After each epoch the evaluation network is run on the evaluation dataset;
    the training network is saved as ``best_ck.ckpt`` whenever the mean
    accuracy improves, and as ``last_ck.ckpt`` at the end of the epoch.

    Args:
        options: parsed command-line options (stored, not used by this class).
        net: network used for evaluation; called as ``net(data, label, classes)``
            and expected to return ``(accuracy, loss)``.
        eval_dataset: dataset exposing ``create_dict_iterator()`` whose items
            have 'data', 'label' and 'classes' entries.
        path (str): directory in which the checkpoints are written.
    """
    def __init__(self, options, net, eval_dataset, path):
        super(EvalCallBack, self).__init__()
        self.net = net
        self.eval_dataset = eval_dataset
        self.path = path
        self.avgacc = 0
        self.avgloss = 0
        self.bestacc = 0
        self.options = options

    def epoch_begin(self, run_context):
        """
        Print a banner marking the start of the current epoch.
        """
        cb_param = run_context.original_args()
        print('=========EPOCH {} BEGIN========='.format(cb_param.cur_epoch_num))

    def epoch_end(self, run_context):
        """
        Evaluate, report the metrics and write the best/last checkpoints.
        """
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        # The checkpoint holds the network being trained, not the evaluation wrapper.
        cur_net = cb_param.network
        self.avgacc, self.avgloss = self.eval(self.eval_dataset, self.net)

        if self.avgacc > self.bestacc:
            self.bestacc = self.avgacc
            print('Epoch {}: Avg Accuracy: {}(best) Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))
            save_checkpoint(cur_net, os.path.join(self.path, 'best_ck.ckpt'))
        else:
            print('Epoch {}: Avg Accuracy: {} Avg Loss:{}'.format(cur_epoch, self.avgacc, self.avgloss))

        # NOTE(review): written every epoch so that 'last_ck.ckpt' always is the
        # most recent state - confirm this matches the intended behaviour.
        save_checkpoint(cur_net, os.path.join(self.path, 'last_ck.ckpt'))
        print("Best Acc:", self.bestacc)
        print('=========EPOCH {} END========='.format(cur_epoch))

    def eval(self, inp, net, n_rounds=10):
        """
        Average accuracy and loss of ``net`` over ``n_rounds`` passes of ``inp``.

        Args:
            inp: dataset exposing ``create_dict_iterator()``.
            net: callable returning ``(accuracy, loss)`` tensors.
            n_rounds (int): number of passes over the dataset; defaults to 10,
                the value that used to be hard-coded.

        Returns:
            tuple: ``(mean accuracy, mean loss)``.
        """
        avg_acc = []
        avg_loss = []
        for _ in range(n_rounds):
            for batch in inp.create_dict_iterator():
                acc, loss = net(batch['data'], batch['label'], batch['classes'])
                avg_acc.append(acc.asnumpy())
                avg_loss.append(loss.asnumpy())
        return np.mean(avg_acc), np.mean(avg_loss)
class IterDatasetGenerator:
    """
    Iterable that yields few-shot episodes.

    Every episode contains ``classes_per_it`` randomly chosen classes with
    ``num_samples`` randomly chosen samples each, in shuffled order.

    Args:
        data: object exposing ``x`` (indexable samples) and ``y`` (labels,
            one per sample).
        classes_per_it (int): number of classes per episode.
        num_samples (int): samples drawn per class (support + query).
        iterations (int): number of episodes produced per pass.

    Raises:
        ValueError: if fewer than ``classes_per_it`` classes own at least
            ``num_samples`` samples. Previously such episodes were silently
            padded with arbitrary sample indices.
    """
    def __init__(self, data, classes_per_it, num_samples, iterations):
        self.__iterations = iterations
        self.__data = data.x
        self.__labels = data.y
        self.__iter = 0
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.classes, inverse, self.counts = np.unique(
            self.__labels, return_inverse=True, return_counts=True)
        self.idxs = range(len(self.__labels))

        # indexes[k] lists, in ascending order, the dataset positions of class
        # k and is NaN-padded up to the size of the largest class.
        widest = int(self.counts.max()) if len(self.counts) else 0
        self.indexes = np.full((len(self.classes), widest), np.nan)
        # A stable sort by class groups the positions while keeping them ascending.
        by_class = np.argsort(np.asarray(inverse).ravel(), kind='stable')
        start = 0
        for row, count in enumerate(self.counts):
            self.indexes[row, :count] = by_class[start:start + count]
            start += count
        self.numel_per_class = self.counts.copy()

        # Only classes that can fill a whole slot are ever sampled.
        self._eligible = np.flatnonzero(self.counts >= self.sample_per_class)
        if len(self._eligible) < self.classes_per_it:
            raise ValueError(
                'Only {} classes have at least {} samples, but {} classes per '
                'episode were requested.'.format(
                    len(self._eligible), self.sample_per_class, self.classes_per_it))

        print('init end')

    def __next__(self):
        """
        Draw one episode.

        Returns:
            tuple: ``(data_x, data_y, data_class)`` where ``data_x`` is the
            list of selected samples, ``data_y`` the int32 array of their
            labels and ``data_class`` the sorted int32 array of the distinct
            labels in the episode.
        """
        spc = self.sample_per_class
        cpi = self.classes_per_it

        if self.__iter >= self.__iterations:
            raise StopIteration

        batch = np.empty(spc * cpi, dtype=np.int32)
        chosen = np.random.permutation(self._eligible)[:cpi]
        for slot, label_idx in enumerate(chosen):
            picks = np.random.permutation(int(self.numel_per_class[label_idx]))[:spc]
            batch[slot * spc:(slot + 1) * spc] = self.indexes[label_idx, picks]
        # Shuffle so that the samples of one class are not contiguous.
        batch = batch[np.random.permutation(len(batch))]

        data_x = [self.__data[b] for b in batch]
        data_y = np.asarray([self.__labels[b] for b in batch], np.int32)
        self.__iter += 1
        data_class = np.asarray(np.unique(data_y), np.int32)
        return data_x, data_y, data_class

    def __iter__(self):
        """Restart the episode counter and return the iterator itself."""
        self.__iter = 0
        return self

    def __len__(self):
        """Number of episodes produced per pass."""
        return self.__iterations
class PrototypicalLoss(_Loss):
    '''
    Prototypical-network loss cell.

    For every class of an episode the first ``n_support`` selected samples are
    averaged into a prototype; the remaining ``n_query`` samples are scored
    with a log-softmax over the negative squared Euclidean distances to all
    prototypes.

    Args:
        n_support (int): support samples per class.
        n_query (int): query samples per class.
        n_class (int): classes per episode (length of the target index table).
        is_train (bool): when True ``construct`` returns the loss only,
            otherwise ``(accuracy, loss)``.
    '''
    def __init__(self, n_support, n_query, n_class, is_train=True):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support
        self.n_query = n_query
        self.eq = ops.Equal()
        self.log_softmax = nn.LogSoftmax(1)
        self.gather = ops.GatherD()
        self.squeeze = ops.Squeeze()
        self.max = ops.Argmax(2)
        self.cast = ops.Cast()
        self.stack = ops.Stack()
        self.reshape = ops.Reshape()
        self.topk = ops.TopK(sorted=True)
        self.expendDims = ops.ExpandDims()
        self.broadcastTo = ops.BroadcastTo((100, 20, 64))
        self.pow = ops.Pow()
        # Previously assigned twice; only this (keep_dims=False) instance was ever live.
        self.sum = ops.ReduceSum()
        self.zeros = Tensor(np.zeros(200), ms.float32)
        self.ones = Tensor(np.ones(200), ms.float32)
        self.print = ops.Print()
        self.unique = ops.Unique()
        self.samples_count = 10
        self.select = ops.Select()
        # Class position 0..n_class-1 used as the gather index / expected prediction.
        self.target_inds = Tensor(list(range(0, n_class)), ms.int32)
        self.is_train = is_train

    def construct(self, inp, target, classes):
        """
        Compute the episode loss (and accuracy when not training).

        Args:
            inp: embeddings of the episode samples, indexed along axis 0.
            target: label of every sample in ``inp``.
            classes: the distinct labels present in the episode.

        Returns:
            ``loss`` when ``is_train`` is True, else ``(accuracy, loss)``.
        """
        n_classes = len(classes)
        n_query = self.n_query
        support_idxs = ()
        query_idxs = ()

        for ind, _ in enumerate(classes):
            class_c = classes[ind]
            # TopK over the 0/1 equality mask yields positions of samples of
            # this class - assumes each class has at least
            # n_support + n_query samples in the episode (TODO confirm).
            _, a = self.topk(self.cast(self.eq(target, class_c), ms.float32), self.n_support + self.n_query)
            support_idx = self.squeeze(a[:self.n_support])
            support_idxs += (support_idx,)
            query_idx = a[self.n_support:]
            query_idxs += (query_idx,)

        # One prototype per class: the mean of its support embeddings.
        prototypes = ()
        for idx_list in support_idxs:
            prototypes += (inp[idx_list].mean(0),)
        prototypes = self.stack(prototypes)

        query_idxs = self.stack(query_idxs).view(-1)
        query_samples = inp[query_idxs]

        dists = euclidean_dist(query_samples, prototypes)

        log_p_y = self.log_softmax(-dists)

        log_p_y = self.reshape(log_p_y, (n_classes, n_query, -1))

        target_inds = self.target_inds.view(n_classes, 1, 1)
        target_inds = ops.BroadcastTo((n_classes, n_query, 1))(target_inds)

        # Negative log-probability of the true class, averaged over all queries.
        loss_val = -self.squeeze(self.gather(log_p_y, 2, target_inds)).view(-1).mean()

        y_hat = self.max(log_p_y)
        acc_val = self.cast(self.eq(y_hat, self.squeeze(target_inds)), ms.float32).mean()
        if self.is_train:
            return loss_val
        return acc_val, loss_val

    def supp_idxs(self, target, c):
        """Positions of the first ``n_support`` entries of ``target`` equal to ``c``."""
        return self.squeeze(self.nonZero(self.eq(target, c))[:self.n_support])

    def nonZero(self, inpbool):
        """
        Return the positions of the truthy entries of ``inpbool`` as an int32 tensor.

        The previous version collected the truthy values themselves instead of
        their positions, so ``supp_idxs`` could never yield usable indices.
        """
        out = []
        for position, flag in enumerate(inpbool):
            if flag:
                out.append(position)
        return Tensor(out, ms.int32)

    def acc(self):
        # NOTE(review): ``acc_val`` is not assigned anywhere in this class, so
        # this accessor looks like a leftover - confirm before relying on it.
        return self.acc_val


def euclidean_dist(x, y):
    '''
    Squared Euclidean distance between every row of ``x`` and every row of ``y``.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M) holding the sum over D of ``(x - y) ** 2``
        (no square root is applied).
    '''
    n = x.shape[0]
    m = y.shape[0]
    d = x.shape[1]

    expand_dims = ops.ExpandDims()
    broadcast_to = ops.BroadcastTo((n, m, d))
    pow_op = ops.Pow()
    reducesum = ops.ReduceSum()

    # Broadcast both operands to (N, M, D) so the difference is taken pairwise.
    x = broadcast_to(expand_dims(x, 1))
    y = broadcast_to(expand_dims(y, 0))
    return reducesum(pow_op(x - y, 2), 2)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset for ProtoNet
"""
import os
from PIL import Image
import numpy as np

# Opened images keyed by file path (rotation suffix stripped), so the same
# file requested with several rotations is only opened once.
IMG_CACHE = {}


class OmniglotDataset():
    """
    Omniglot dataset class.

    All images of the requested split are loaded into memory on construction.

    Args:
        mode: split name; classes are read from ``<root>/splits/vinyals/<mode>.txt``.
        root: dataset root directory.
        transform: optional callable applied to an image in ``__getitem__``.
        target_transform: optional callable applied to each integer label.
    """

    splits_folder = os.path.join('splits', 'vinyals')
    raw_folder = 'raw'
    processed_folder = 'data'

    def __init__(self, mode='train', root='.' + os.sep + 'dataset', transform=None, target_transform=None):
        self.root = root
        print(self.root)
        self.transform = transform
        self.target_transform = target_transform

        self.classes = get_current_classes(os.path.join(
            self.root, self.splits_folder, mode + '.txt'))
        self.all_items = find_items(os.path.join(
            self.root, self.processed_folder), self.classes)

        self.idx_classes = index_classes(self.all_items)
        labelled = [self.get_path_label(i) for i in range(len(self))]
        # zip(*[]) produces nothing to unpack, so an empty split needs its own branch.
        paths, self.y = zip(*labelled) if labelled else ((), ())
        self.x = [load_img(path, i) for i, path in enumerate(paths)]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``, applying ``transform`` if set."""
        x = self.x[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]

    def __len__(self):
        """Number of (file, rotation) items found for the split."""
        return len(self.all_items)

    def get_path_label(self, index):
        """
        Return ``(path, label)`` for item ``index``.

        ``path`` is ``<folder><sep><filename><sep>rotNNN``; the rotation suffix is
        split off again by ``load_img``. ``label`` is the integer class index
        (after ``target_transform``, when one is given).
        """
        filename, label, folder, rot = self.all_items[index]
        img = os.sep.join([folder, filename]) + rot
        target = self.idx_classes[label + rot]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


def find_items(root_dir, classes):
    """
    Walk ``root_dir`` and collect every png file whose class is in ``classes``.

    A class name is ``<parent dir><sep><dir><sep>rotNNN``; each file is emitted once
    per rotation listed for its directory.

    Returns:
        list of ``(filename, label, folder, rotation_suffix)`` tuples.
    """
    retour = []
    rots = [os.sep + 'rot000', os.sep + 'rot090', os.sep + 'rot180', os.sep + 'rot270']
    # Set membership is O(1); the class list is probed once per rotation per directory.
    wanted = set(classes)
    for (root, _, files) in os.walk(root_dir):
        # The label depends only on the directory, so compute it once per directory.
        r = root.split(os.sep)
        lr = len(r)
        label = r[lr - 2] + os.sep + r[lr - 1]
        matching_rots = [rot for rot in rots if label + rot in wanted]
        if not matching_rots:
            continue
        for f in files:
            if f.endswith("png"):
                retour.extend((f, label, root, rot) for rot in matching_rots)
    print("== Dataset: Found %d items " % len(retour))
    return retour


def index_classes(items):
    """
    Assign a consecutive integer index to every distinct ``label + rotation``
    key, in order of first appearance in ``items``.
    """
    idx = {}
    for item in items:
        key = item[1] + item[-1]
        if key not in idx:
            idx[key] = len(idx)
    print("== Dataset: Found %d classes" % len(idx))
    return idx


def get_current_classes(fname):
    """
    Read the class names of a split file, one per line, converting ``/`` to
    the platform path separator.
    """
    with open(fname) as f:
        classes = f.read().replace('/', os.sep).splitlines()
    return classes


def load_img(path, idx):
    """
    Load one image as a float32 array of shape ``(1, 28, 28)``.

    Args:
        path: ``<file path><sep>rotNNN`` as produced by ``get_path_label``.
        idx: position of the item; unused, kept for interface compatibility.

    Returns:
        ``1.0 - pixels`` of the rotated, resized image, transposed and reshaped
        to ``(1, width, height)``.
    """
    # Split on the LAST '<sep>rot' only: a directory called e.g. 'rotated'
    # earlier in the path must not be mistaken for the rotation suffix.
    path, rot = path.rsplit(os.sep + 'rot', 1)
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    x = x.rotate(float(rot))
    x = x.resize((28, 28))

    shape = 1, x.size[0], x.size[1]
    # np.asarray copies only when required; np.array(..., copy=False) raises on
    # NumPy >= 2 whenever the conversion cannot be done without a copy.
    x = np.asarray(x, dtype=np.float32)
    x = 1.0 - x
    x = x.T.reshape(shape)

    return x


# ---- patch metadata for the next file, kept verbatim from the original diff ----
# diff --git a/model_zoo/research/cv/ProtoNet/src/parser_util.py b/model_zoo/research/cv/ProtoNet/src/parser_util.py
# new file mode 100644
# index 00000000000..906d5385bd7
# --- /dev/null
# +++ b/model_zoo/research/cv/ProtoNet/src/parser_util.py
# @@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet parser_util script.
"""
import os
import argparse


def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def get_parser():
    """
    Build the command-line parser shared by the ProtoNet scripts.

    Returns:
        argparse.ArgumentParser, not yet parsed; call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser()

    # Without a converter, "--run_offline False" would store the non-empty
    # (hence truthy) string 'False'.
    parser.add_argument('--run_offline', type=_str2bool, default=True,
                        help='run in offline is False or True')

    parser.add_argument('--data_url', default=None, help='Location of data.')

    parser.add_argument('--train_url', default=None, help='Location of training outputs.')

    parser.add_argument('--ckpt_url', default=None, help='Location of checkpoint.')

    parser.add_argument('-root', '--dataset_root',
                        type=str,
                        help='path to dataset',
                        default='..' + os.sep + 'dataset')

    parser.add_argument('-target', '--device_target',
                        type=str,
                        help='device to run on, default=Ascend',
                        default='Ascend')
    parser.add_argument('-id', '--device_id',
                        type=int,
                        help='device id, default=0',
                        default=0)

    parser.add_argument('-exp', '--experiment_root',
                        type=str,
                        help='root where to store models, losses and accuracies',
                        default='..' + os.sep + 'output')

    parser.add_argument('-nep', '--epochs',
                        type=int,
                        help='number of epochs to train for, default=2',
                        default=2)

    parser.add_argument('-lr', '--learning_rate',
                        type=float,
                        help='learning rate for the model, default=0.001',
                        default=0.001)

    parser.add_argument('-lrS', '--lr_scheduler_step',
                        type=int,
                        help='StepLR learning rate scheduler step, default=20',
                        default=20)

    parser.add_argument('-lrG', '--lr_scheduler_gamma',
                        type=float,
                        help='StepLR learning rate scheduler gamma, default=0.5',
                        default=0.5)

    parser.add_argument('-its', '--iterations',
                        type=int,
                        help='number of episodes per epoch, default=100',
                        default=100)

    parser.add_argument('-cTr', '--classes_per_it_tr',
                        type=int,
                        help='number of random classes per episode for training, default=20',
                        default=20)

    parser.add_argument('-nsTr', '--num_support_tr',
                        type=int,
                        help='number of samples per class to use as support for training, default=5',
                        default=5)

    parser.add_argument('-nqTr', '--num_query_tr',
                        type=int,
                        help='number of samples per class to use as query for training, default=5',
                        default=5)

    parser.add_argument('-cVa', '--classes_per_it_val',
                        type=int,
                        help='number of random classes per episode for validation, default=20',
                        default=20)

    parser.add_argument('-nsVa', '--num_support_val',
                        type=int,
                        help='number of samples per class to use as support for validation, default=5',
                        default=5)

    parser.add_argument('-nqVa', '--num_query_val',
                        type=int,
                        help='number of samples per class to use as query for validation, default=15',
                        default=15)

    parser.add_argument('-seed', '--manual_seed',
                        type=int,
                        help='input for the manual seeds initializations',
                        default=7)

    parser.add_argument('--cuda',
                        action='store_true',
                        help='enables cuda')

    return parser


# ---- patch metadata for the next file, kept verbatim from the original diff ----
# diff --git a/model_zoo/research/cv/ProtoNet/src/protonet.py b/model_zoo/research/cv/ProtoNet/src/protonet.py
# new file mode 100644
# index
# ---- remainder of the patch header for protonet.py, kept verbatim ----
# 00000000000..3f0258656dd
# --- /dev/null
# +++ b/model_zoo/research/cv/ProtoNet/src/protonet.py
# @@ -0,0 +1,257 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ProtoNet.
"""
from functools import reduce
import math
import numpy as np
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops
from mindspore.common import initializer as init


def _calculate_gain(nonlinearity, param=None):
    r"""
    Return the recommended gain value for the given nonlinearity function.

    The values are as follows:
    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function
        param: optional parameter for the non-linear function

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or ``param`` is not
            a number when ``nonlinearity`` is ``'leaky_relu'``.

    Examples:
        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # bool is a subclass of int, so it has to be excluded explicitly.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _assignment(arr, num):
    """Assign the value of `num` to `arr` in place and return `arr`."""
    if arr.shape == ():
        # A 0-d array cannot be slice-assigned; go through a 1-element view.
        arr = arr.reshape((1))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_in_and_out(arr):
    """
    Calculate n_in and n_out.

    Args:
        arr (Array): Input array; only its ``shape`` is read.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.

    Raises:
        ValueError: if ``arr`` has fewer than 2 dimensions.
    """
    dim = len(arr.shape)
    if dim < 2:
        raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")

    n_in = arr.shape[1]
    n_out = arr.shape[0]

    if dim > 2:
        # Every dimension past the first two scales both fans.
        counter = reduce(lambda x, y: x * y, arr.shape[2:])
        n_in *= counter
        n_out *= counter
    return n_in, n_out


def _select_fan(array, mode):
    """
    Return the fan-in or fan-out of ``array`` according to ``mode``
    (``'fan_in'`` or ``'fan_out'``, case-insensitive).

    Raises:
        ValueError: if ``mode`` is neither of the two.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_in_and_out(array)
    return fan_in if mode == 'fan_in' else fan_out


class KaimingInit(init.Initializer):
    r"""
    Base Class. Initialize the array with He kaiming algorithm.

    Args:
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function, recommended to use only with
            ``'relu'`` or ``'leaky_relu'`` (default).
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingInit, self).__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        # Intentionally empty here; KaimingUniform and KaimingNormal override it.
        pass


class KaimingUniform(KaimingInit):
    r"""
    Initialize the array with He kaiming uniform algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingUniform(mode='fan_in', nonlinearity='relu')._initialize(w)
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        data = np.random.uniform(-bound, bound, arr.shape)

        _assignment(arr, data)


class KaimingNormal(KaimingInit):
    r"""
    Initialize the array with He kaiming normal algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingNormal(mode='fan_out', nonlinearity='relu')._initialize(w)
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        std = self.gain / math.sqrt(fan)
        data = np.random.normal(0, std, arr.shape)

        _assignment(arr, data)


def conv_block(in_channels, out_channels):
    '''
    returns a block conv-bn-relu-pool
    '''
    return nn.SequentialCell(
        nn.Conv2d(in_channels, out_channels, 3, pad_mode='pad', padding=1, has_bias=True),
        nn.BatchNorm2d(out_channels, momentum=0.1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2)
    )


class ProtoNet(nn.Cell):
    '''
    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84

    Four ``conv_block`` stages followed by a flatten to ``(batch, features)``.

    Args:
        x_dim: number of input channels.
        hid_dim: number of channels of the three inner blocks.
        z_dim: number of channels of the last block.
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.SequentialCell(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )
        self._initialize_weights()
        self.print = ops.Print()
        # Create the operator once instead of on every construct() call.
        self.reshape = ops.Reshape()

    def construct(self, x):
        x = self.encoder(x)
        # Flatten everything but the leading (batch) dimension.
        x = self.reshape(x, (x.shape[0], -1))
        return x

    def _initialize_weights(self):
        '''
        Re-initialise every Conv2d: Kaiming-uniform weights with a=sqrt(5) and
        a bias drawn via init.Uniform(1 / sqrt(fan_in)).
        '''
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                      cell.weight.shape,
                                                      ms.float32))
                if cell.bias is not None:
                    fan_in, _ = _calculate_in_and_out(cell.weight)
                    bound = 1 / math.sqrt(fan_in)
                    cell.bias.set_data(init.initializer(init.Uniform(bound),
                                                        cell.bias.shape,
                                                        ms.float32))


class WithLossCell(nn.Cell):
    """
    Wrap a backbone and a loss function that additionally receives ``classes``.

    ``construct(data, label, classes)`` runs the backbone on ``data`` only and
    returns ``loss_fn(backbone_output, label, classes)``.

    Args:
        backbone: network applied to ``data``.
        loss_fn: callable taking ``(output, label, classes)``.

    Examples:
        >>> net_with_criterion = WithLossCell(net, loss_fn)
        >>> output_data = net_with_criterion(data, label, classes)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label, classes):
        out = self._backbone(data)
        return self._loss_fn(out, label, classes)


# ---- patch metadata for the next file, kept verbatim from the original diff ----
# diff --git a/model_zoo/research/cv/ProtoNet/train.py b/model_zoo/research/cv/ProtoNet/train.py
# new file mode 100644
# index 00000000000..9ba1873d1a6
# --- /dev/null
# +++ b/model_zoo/research/cv/ProtoNet/train.py
# @@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# (opening of the train.py module docstring; its closing delimiter is on the next patch line)
# """
# ProtoNet train script.
# """  (closing delimiter of the train.py module docstring opened on the previous patch line)
import os
import datetime
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore import dataset as ds
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from src.EvalCallBack import EvalCallBack
from src.protonet import WithLossCell
from src.PrototypicalLoss import PrototypicalLoss
from src.parser_util import get_parser
from src.protonet import ProtoNet
from model_init import init_dataloader

local_data_url = './cache/data'
local_train_url = './cache/out'


def train(opt, tr_dataloader, net, loss_fn, eval_loss_fn, optim, path, val_dataloader=None):
    '''
    Train ``net`` and report wall-clock timing.

    Args:
        opt: parsed options; ``epochs`` and ``iterations`` are read here and the
            whole object is handed to ``EvalCallBack``.
        tr_dataloader: source for the training ``GeneratorDataset``.
        net: backbone network (also the network saved in checkpoints).
        loss_fn: training loss, wrapped together with ``net`` in ``WithLossCell``.
        eval_loss_fn: evaluation loss, used only when ``val_dataloader`` is given.
        optim: optimizer passed to ``Model``.
        path: directory for checkpoints; also passed to ``EvalCallBack``.
        val_dataloader: optional source for the evaluation dataset; when ``None``
            no evaluation callback is registered.
    '''
    columns = ['data', 'label', 'classes']
    inp = ds.GeneratorDataset(tr_dataloader, column_names=columns)
    my_loss_cell = WithLossCell(net, loss_fn)
    model = Model(my_loss_cell, optimizer=optim)

    config = CheckpointConfig(save_checkpoint_steps=10,
                              keep_checkpoint_max=5,
                              saved_network=net)
    ckpoint_cb = ModelCheckpoint(prefix='protonet', directory=path, config=config)
    callbacks = [ckpoint_cb]

    # The parameter defaults to None, so evaluation has to be optional.
    if val_dataloader is not None:
        my_acc_cell = WithLossCell(net, eval_loss_fn)
        eval_data = ds.GeneratorDataset(val_dataloader, column_names=columns)
        callbacks.append(EvalCallBack(opt, my_acc_cell, eval_data, path))
    callbacks.append(TimeMonitor())

    print('==========training test==========')
    starttime = datetime.datetime.now()
    model.train(opt.epochs, inp, callbacks=callbacks)
    endtime = datetime.datetime.now()

    # total_seconds() keeps the fractional part and does not wrap after a day.
    elapsed = (endtime - starttime).total_seconds()
    n_epochs = max(opt.epochs, 1)
    # NOTE(review): assumes one epoch is `opt.iterations` steps (the option is
    # documented as "episodes per epoch") - confirm against the dataloader.
    n_steps = n_epochs * max(opt.iterations, 1)
    print('epoch time: ', elapsed / n_epochs, 'per step time:', elapsed / n_steps)


def _init_data_parallel(device_num):
    '''Initialise communication and switch the context to data-parallel mode.'''
    init()
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)


def main():
    '''
    Parse options, set up the device context and data locations, then train.

    With ``run_offline`` the dataset and output directories come from the
    command line; otherwise data is copied in from ``data_url`` and results are
    copied out to ``train_url`` through moxing.
    '''
    global local_data_url
    global local_train_url

    options = get_parser().parse_args()
    device_num = int(os.environ.get("DEVICE_NUM", 1))

    # NOTE(review): the nesting of the two branches below was reconstructed from
    # whitespace-collapsed source - confirm against the original file.
    if options.run_offline:
        if device_num > 1:
            _init_data_parallel(device_num)
        context.set_context(device_id=options.device_id)
        local_data_url = options.dataset_root
        local_train_url = options.experiment_root
        os.makedirs(options.experiment_root, exist_ok=True)
    else:
        # Fall back to device 0 instead of failing on int(None) when unset.
        device_id = int(os.getenv("DEVICE_ID", "0"))

        import moxing as mox
        os.makedirs(local_train_url, exist_ok=True)

        context.set_context(device_id=device_id)

        if device_num > 1:
            _init_data_parallel(device_num)
            local_data_url = os.path.join(local_data_url, str(device_id))
            local_train_url = os.path.join(local_train_url, str(device_id))

        mox.file.copy_parallel(src_url=options.data_url, dst_url=local_data_url)

    tr_dataloader = init_dataloader(options, 'train', local_data_url)
    val_dataloader = init_dataloader(options, 'val', local_data_url)

    loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_tr)
    # NOTE(review): evaluation is built with the *training* support/query counts
    # although num_support_val / num_query_val exist - confirm this is intended.
    eval_loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr, options.classes_per_it_val,
                                    is_train=False)

    net = ProtoNet()
    # Honour --learning_rate (its default equals the previously hard-coded 0.001).
    optim = nn.Adam(params=net.trainable_params(), learning_rate=options.learning_rate)
    train(options, tr_dataloader, net, loss_fn, eval_loss_fn, optim, local_train_url, val_dataloader)
    if not options.run_offline:
        mox.file.copy_parallel(src_url='./cache/out', dst_url=options.train_url)


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE)
    main()