From ca49de3ce8d3d34363bb9ea7f62f203d1f446826 Mon Sep 17 00:00:00 2001
From: shuait <1415109792@qq.com>
Date: Tue, 16 Mar 2021 16:01:12 +0800
Subject: [PATCH] update readme, eval.py create /research/cv/glore

---
 .../research/cv/Glore_resnet200/README_CN.md  | 252 ++++++++
 .../research/cv/Glore_resnet200/__init__.py   |   0
 model_zoo/research/cv/Glore_resnet200/eval.py |  85 +++
 .../research/cv/Glore_resnet200/export.py     |  50 ++
 .../scripts/run_distribute_train.sh           |  87 +++
 .../cv/Glore_resnet200/scripts/run_eval.sh    |  47 ++
 .../scripts/run_standalone_train.sh           |  47 ++
 .../cv/Glore_resnet200/src/__init__.py        |   0
 .../research/cv/Glore_resnet200/src/config.py |  36 ++
 .../cv/Glore_resnet200/src/dataset.py         | 125 ++++
 .../cv/Glore_resnet200/src/glore_resnet200.py | 336 ++++++++++
 .../research/cv/Glore_resnet200/src/loss.py   |  53 ++
 .../cv/Glore_resnet200/src/lr_generator.py    | 128 ++++
 .../cv/Glore_resnet200/src/transform.py       |  51 ++
 .../cv/Glore_resnet200/src/transform_utils.py | 594 ++++++++++++++++++
 .../research/cv/Glore_resnet200/train.py      | 179 ++++++
 16 files changed, 2070 insertions(+)
 create mode 100644 model_zoo/research/cv/Glore_resnet200/README_CN.md
 create mode 100644 model_zoo/research/cv/Glore_resnet200/__init__.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/eval.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/export.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/scripts/run_distribute_train.sh
 create mode 100644 model_zoo/research/cv/Glore_resnet200/scripts/run_eval.sh
 create mode 100644 model_zoo/research/cv/Glore_resnet200/scripts/run_standalone_train.sh
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/__init__.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/config.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/dataset.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/glore_resnet200.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/loss.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/lr_generator.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/transform.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/src/transform_utils.py
 create mode 100644 model_zoo/research/cv/Glore_resnet200/train.py

diff --git a/model_zoo/research/cv/Glore_resnet200/README_CN.md b/model_zoo/research/cv/Glore_resnet200/README_CN.md
new file mode 100644
index 00000000000..200ba7a29a5
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/README_CN.md
@@ -0,0 +1,252 @@
+
+# Contents
+
+- [Glore_resnet200 Description](#glore_resnet200-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Features](#features)
+    - [Mixed Precision](#mixed-precision)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+    - [Training Results](#training-results)
+    - [Inference Process](#inference-process)
+    - [Inference Results](#inference-results)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Inference Performance](#inference-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# Glore_resnet200 Description
+
+## Overview
+
+Convolutional neural networks are good at extracting local relations, but they are inefficient at modeling relations between distant regions of an image and need many stacked layers to do so, while global reasoning over regions benefits many computer vision tasks. To enable such reasoning, Facebook Research, the National University of Singapore and the 360 AI Institute proposed the graph-based Global Reasoning Unit, a module that can be plugged into the network models of many tasks. glore_res200 is an image classification network that evenly inserts 2 and 3 global reasoning units into Stage 2 and Stage 3 of ResNet200, respectively.
+
+## Paper
+
+1. [Paper](https://arxiv.org/abs/1811.12814): Yunpeng Chen, Marcus Rohrbach, Zhicheng Yan, Shuicheng Yan, Jiashi Feng, Yannis Kalantidis
+
+# Model Architecture
+
+The backbone of the network is ResNet200, with 2 and 3 global reasoning units evenly inserted into Stage 2 and Stage 3, respectively. The units are inserted in the same way in both stages.
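+
+The following NumPy sketch summarizes the dataflow of one global reasoning unit as implemented in `src/glore_resnet200.py` (class `GloreUnit`). It is only a schematic illustration of the projection -> graph reasoning -> reverse projection pipeline; for brevity, the unit's 1x1 convolutions, batch normalization and GCN are replaced here by fixed matrices and a ReLU:
+
+```python
+import numpy as np
+
+def glore_unit_sketch(x, w_state, w_proj):
+    """Schematic GloreUnit dataflow: x is (n, c, h, w);
+    w_state (num_s, c) and w_proj (num_n, c) stand in for 1x1 convs."""
+    n, c, h, w = x.shape
+    num_s, num_n = w_state.shape[0], w_proj.shape[0]
+    x_flat = x.reshape(n, c, h * w)
+    state = np.einsum('sc,ncl->nsl', w_state, x_flat)  # reduce dim: (n, num_s, h*w)
+    proj = np.einsum('mc,ncl->nml', w_proj, x_flat)    # projection maps: (n, num_n, h*w)
+    # coordinate space -> interaction space: one state vector per graph node
+    nodes = np.einsum('nsl,nml->nsm', state, proj)     # (n, num_s, num_n)
+    nodes = np.maximum(nodes, 0)                       # stand-in for the GCN reasoning step
+    # reverse projection: interaction space -> coordinate space
+    out = np.einsum('nsm,nml->nsl', nodes, proj)       # (n, num_s, h*w)
+    return out.reshape(n, num_s, h, w)
+```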
+
+# Dataset
+
+Dataset used: [ImageNet2012](http://www.image-net.org/)
+
+- Dataset size: 1000 classes, 224*224 color images
+    - Training set: 1,281,167 images
+    - Test set: 50,000 images
+- Data format: JPEG
+    - Note: data is processed in dataset.py.
+- Download the dataset; the directory structure is as follows:
+
+```text
+└─dataset
+    ├─train                 # training set
+    └─val                   # evaluation set
+```
+
+# Features
+
+## Mixed Precision
+
+[Mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training uses both single-precision and half-precision data to speed up the training of deep neural networks while maintaining the accuracy achievable with pure single-precision training. It increases computing speed and reduces memory usage, and it enables training larger models or using larger batch sizes on specific hardware.
+Taking the FP16 operator as an example, if the input data type is FP32, the MindSpore backend automatically reduces the precision to process the data. Users can open the INFO log and search for "reduce precision" to see which operators ran at reduced precision.
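+
+As a minimal sketch of how mixed precision and loss scaling are typically wired up in a MindSpore training script (`net`, `loss` and `opt` are placeholders assumed to be defined elsewhere; the settings actually used for this model live in `train.py` and `src/config.py`):
+
+```python
+from mindspore import Model
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+
+# 1024 matches the "loss_scale" entry in src/config.py
+loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
+
+# amp_level="O2" casts the network to FP16 while keeping BatchNorm in FP32
+model = Model(net, loss_fn=loss, optimizer=opt,
+              loss_scale_manager=loss_scale, amp_level="O2")
+```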
+
+# Environment Requirements
+
+- Hardware (Ascend)
+    - Prepare an Ascend processor to set up the hardware environment.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
+
+# Quick Start
+
+After installing MindSpore through the official website, you can start training and evaluation as follows:
+
+- Running on Ascend
+
+```bash
+# distributed training
+Usage: bash run_distribute_train.sh [DATASET_PATH] [RANK_SIZE]
+
+# standalone training
+Usage: bash run_standalone_train.sh [DATASET_PATH] [DEVICE_ID]
+
+# evaluation example
+Usage: bash run_eval.sh [DATASET_PATH] [DEVICE_ID] [CHECKPOINT_PATH]
+```
+
+  For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
+
+  Please follow the instructions in [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
+
+# Script Description
+
+## Script and Sample Code
+
+```shell
+.
+└──Glore_resnet200
+  ├── README_CN.md
+  ├── scripts
+    ├── run_distribute_train.sh            # launch Ascend distributed training (8 pcs)
+    ├── run_eval.sh                        # launch Ascend evaluation (1 pcs)
+    └── run_standalone_train.sh            # launch Ascend standalone training (1 pcs)
+  ├── src
+    ├── __init__.py
+    ├── config.py                          # parameter configuration
+    ├── dataset.py                         # dataset loading
+    ├── loss.py                            # loss function
+    ├── lr_generator.py                    # learning rate schedule
+    ├── glore_resnet200.py                 # glore_resnet200 network
+    ├── transform.py                       # data augmentation
+    └── transform_utils.py                 # data augmentation utilities
+  ├── eval.py                              # evaluation script
+  ├── export.py                            # checkpoint export script
+  └── train.py                             # training script
+```
+
+## Script Parameters
+
+- Parameters of Glore_resnet200 on the ImageNet2012 dataset.
+
+```text
+"class_num":1000,                # number of dataset classes
+"batch_size":128,                # batch size of the input tensor
+"loss_scale":1024,               # loss scale
+"momentum":0.08,                 # optimizer momentum
+"weight_decay":0.0002,           # weight decay
+"epoch_size":150,                # applies to training only; fixed to 1 for inference
+"pretrain_epoch_size":0,         # number of epochs the loaded pretrained checkpoint was already trained for; the actual number of training epochs equals epoch_size minus pretrain_epoch_size
+"save_checkpoint":True,          # whether to save checkpoints
+"save_checkpoint_epochs":5,      # epoch interval between two checkpoints; by default the last checkpoint is saved after the last epoch
+"keep_checkpoint_max":10,        # keep only the last keep_checkpoint_max checkpoints
+"save_checkpoint_path":"./",     # checkpoint save path, relative to the execution path
+"warmup_epochs":0,               # number of warm-up epochs
+"lr_decay_mode":"poly",          # decay mode used to generate the learning rate
+"lr_init":0.1,                   # initial learning rate
+"lr_max":0.4,                    # maximum learning rate
+"lr_end":0.0,                    # minimum learning rate
+```
+
+For more configuration details, please refer to the script `config.py`.
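+
+The "poly" decay mode above corresponds to the schedule implemented in `src/lr_generator.py`: after an optional linear warm-up, the learning rate decays from lr_max towards 0 following a squared polynomial. A minimal self-contained sketch of that branch:
+
+```python
+def poly_lr(lr_init, lr_max, warmup_steps, total_steps):
+    """Mirrors the 'poly' branch of get_lr() in src/lr_generator.py."""
+    lr_each_step = []
+    inc = (lr_max - lr_init) / warmup_steps if warmup_steps else 0.0
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = lr_init + inc * i
+        else:
+            base = 1.0 - (i - warmup_steps) / (total_steps - warmup_steps)
+            lr = max(lr_max * base * base, 0.0)
+        lr_each_step.append(lr)
+    return lr_each_step
+
+# 150 epochs * 1251 steps per epoch, no warm-up: starts at 0.4 and decays to 0
+schedule = poly_lr(lr_init=0.1, lr_max=0.4, warmup_steps=0, total_steps=150 * 1251)
+```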
+
+## Training Process
+
+```text
+# distributed training
+Usage: bash run_distribute_train.sh [DATASET_PATH] [RANK_SIZE]
+
+# standalone training
+Usage: bash run_standalone_train.sh [DATASET_PATH] [DEVICE_ID]
+
+# evaluation example
+Usage: bash run_eval.sh [DATASET_PATH] [DEVICE_ID] [CHECKPOINT_PATH]
+```
+
+Distributed training requires an HCCL configuration file in JSON format to be created in advance.
+
+For details, see the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
+
+Training results are saved in the example path, in folders whose names start with "train" or "train_parallel". You can find the checkpoint files and results in the logs under this path, as shown below.
+
+## Training Results
+
+- Training Glore_resnet200 with the ImageNet2012 dataset (8 pcs)
+
+```text
+# distributed training results (8P)
+epoch:1 step:1251, loss is 6.0563216
+epoch:2 step:1251, loss is 5.3812423
+epoch:3 step:1251, loss is 4.782114
+epoch:4 step:1251, loss is 4.4079633
+epoch:5 step:1251, loss is 4.080069
+...
+```
+
+## Inference Process
+
+```bash
+# evaluation
+Usage: bash run_eval.sh [DATASET_PATH] [DEVICE_ID] [CHECKPOINT_PATH]
+```
+
+```bash
+# evaluation example
+bash run_eval.sh ~/Imagenet 0 ~/glore_resnet200-150_1251.ckpt
+```
+
+## Inference Results
+
+```text
+result:{'top_1 acc':0.7974158653846154}
+```
+
+# Model Description
+
+## Performance
+
+### Training Performance
+
+#### Glore_resnet200 on ImageNet2012
+
+| Parameters                 | Ascend 910                                                    |
+| -------------------------- | ------------------------------------------------------------- |
+| Model version              | Glore_resnet200                                                |
+| Resources                  | Ascend 910; CPU 2.60GHz, 192 cores; memory 2048G; OS Euler2.8  |
+| Upload date                | 2021-03-24                                                     |
+| MindSpore version          | 1.1.1-c76                                                      |
+| Dataset                    | ImageNet2012                                                   |
+| Training parameters        | epoch=150, steps per epoch=1251, batch_size=128                |
+| Optimizer                  | NAG                                                            |
+| Loss function              | SoftmaxCrossEntropyExpand                                      |
+| Output                     | probability                                                    |
+| Loss                       | 0.7068262                                                      |
+| Speed                      | 630.343 ms/step (8 pcs)                                        |
+| Total time                 | 33 h 45 min                                                    |
+| Parameters (M)             | 70.6                                                           |
+| Checkpoint for fine-tuning | 807.57M (.ckpt file)                                           |
+| Script                     | [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/Glore_resnet200) |
+
+### Inference Performance
+
+#### Glore_resnet200 on ImageNet2012
+
+| Parameters        | Ascend                      |
+| ----------------- | --------------------------- |
+| Model version     | Glore_resnet200             |
+| Resources         | Ascend 910                  |
+| Upload date       | 2021-03-24                  |
+| MindSpore version | 1.1.1-c76                   |
+| Dataset           | ImageNet2012, 50,000 images |
+| batch_size        | 128                         |
+| Output            | probability                 |
+| Accuracy          | 8 pcs: 80.23%               |
+
+# Description of Random Situation
+
+transform_utils.py uses a random selection strategy for data augmentation, and train.py uses a random seed.
+
+# ModelZoo Homepage
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
\ No newline at end of file
diff --git a/model_zoo/research/cv/Glore_resnet200/__init__.py b/model_zoo/research/cv/Glore_resnet200/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/model_zoo/research/cv/Glore_resnet200/eval.py b/model_zoo/research/cv/Glore_resnet200/eval.py
new file mode 100644
index 00000000000..bdff20b81c3
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/eval.py
@@ -0,0 +1,85 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+##############test glore_resnet200 example on Imagenet2012#################
+python eval.py
+"""
+import random
+import argparse
+import ast
+import numpy as np
+from mindspore import context
+from mindspore import dataset as de
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src.glore_resnet200 import glore_resnet200
+from src.dataset import create_dataset_ImageNet as ImageNet
+from src.loss import SoftmaxCrossEntropyExpand
+from src.config import config
+
+parser = argparse.ArgumentParser(description='Image classification with glore_resnet200')
+parser.add_argument('--use_glore', type=ast.literal_eval, default=True, help='Enable GloreUnit')
+parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
+parser.add_argument('--train_url', type=str, help='Train output in modelarts')
+parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
+parser.add_argument('--device_id', type=int, default=0)
+parser.add_argument('--ckpt_path', type=str, default=None)
+parser.add_argument('--isModelArts', type=ast.literal_eval, default=True)
+parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
+args_opt = parser.parse_args()
+
+if args_opt.isModelArts:
+    import moxing as mox
+
+random.seed(1)
+np.random.seed(1)
+de.config.set_seed(1)
+
+
+if __name__ == '__main__':
+    target = args_opt.device_target
+    # init context
+    device_id = args_opt.device_id
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
+                        device_id=device_id)
+
+    # dataset
+    eval_dataset_path = args_opt.data_url
+    if args_opt.isModelArts:
+        mox.file.copy_parallel(src_url=args_opt.data_url, dst_url='/cache/dataset')
+        eval_dataset_path = '/cache/dataset/'
+    predict_data = ImageNet(dataset_path=eval_dataset_path,
+                            do_train=False,
+                            repeat_num=1,
+                            batch_size=config.batch_size,
+                            target=target)
+    step_size = predict_data.get_dataset_size()
+    if step_size == 0:
+        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
+
+    # define net
+    net = glore_resnet200(class_num=config.class_num, use_glore=args_opt.use_glore)
+
+    # load checkpoint
+    param_dict = load_checkpoint(args_opt.ckpt_path)
+    load_param_into_net(net, param_dict)
+
+    # define loss, model
+    loss = SoftmaxCrossEntropyExpand(sparse=True)
+    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+    print("============== Starting Testing ==============")
+    acc = model.eval(predict_data)
+    print("==============Acc: {} ==============".format(acc))
diff --git a/model_zoo/research/cv/Glore_resnet200/export.py b/model_zoo/research/cv/Glore_resnet200/export.py
new file mode 100644
index 00000000000..84df07e4253
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/export.py
@@ -0,0 +1,50 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +##############export checkpoint file into air, onnx, mindir models################# +python export.py +""" +import argparse +import numpy as np + +from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context +import mindspore.common.dtype as ms + +from src.config import config +from src.glore_resnet200 import glore_resnet200 + +parser = argparse.ArgumentParser(description='Classification') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--file_name", type=str, default="glore_resnet200", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", + help="device target") +parser.add_argument("--ckpt_path", type=str, default=None) + +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) + +if __name__ == '__main__': + net = glore_resnet200(class_num=config.class_num) + assert args.ckpt_path is not None, "arg.ckpt_path is None." + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.ones([args.batch_size, 3, 224, 224]), ms.float32) + export(net, input_arr, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/research/cv/Glore_resnet200/scripts/run_distribute_train.sh b/model_zoo/research/cv/Glore_resnet200/scripts/run_distribute_train.sh new file mode 100644 index 00000000000..d920cd1820c --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/scripts/run_distribute_train.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_distribute_train.sh DATA_PATH RANK_SIZE" +echo "For example: bash run_distribute_train.sh /path/dataset 8" +echo "It is better to use the absolute path." 
+echo "==============================================================================================================" +set -e +DATA_PATH=$1 +export DATA_PATH=${DATA_PATH} +RANK_SIZE=$2 + +EXEC_PATH=$(pwd) + +echo "$EXEC_PATH" + +test_dist_8pcs() +{ + export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json + export RANK_SIZE=8 +} + +test_dist_2pcs() +{ + export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json + export RANK_SIZE=2 +} + +test_dist_${RANK_SIZE}pcs + +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +for((i=1;i<${RANK_SIZE};i++)) +do + rm -rf device$i + mkdir device$i + cd ./device$i + mkdir src + cd ../ + cp ../*.py ./device$i + cp ../src/*.py ./device$i/src + cd ./device$i + export DEVICE_ID=$i + export RANK_ID=$i + echo "start training for device $i" + env > env$i.log + python3 train.py --data_url $1 --isModelArts False --run_distribute True > train$i.log 2>&1 & + echo "$i finish" + cd ../ +done +rm -rf device0 +mkdir device0 +cd ./device0 +mkdir src +cd ../ +cp ../*.py ./device0 +cp ../src/*.py ./device0/src +cd ./device0 +export DEVICE_ID=0 +export RANK_ID=0 +echo "start training for device 0" +env > env0.log +python3 train.py --data_url $1 --isModelArts False --run_distribute True > train0.log 2>&1 + +if [ $? -eq 0 ];then + echo "training success" +else + echo "training failed" + exit 2 +fi +echo "finish" +cd ../ diff --git a/model_zoo/research/cv/Glore_resnet200/scripts/run_eval.sh b/model_zoo/research/cv/Glore_resnet200/scripts/run_eval.sh new file mode 100644 index 00000000000..6ee62ba2c98 --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/scripts/run_eval.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_eval.sh DATA_PATH DEVICE_ID CKPT_PATH" +echo "For example: bash run_eval.sh /path/dataset 0 /path/ckpt" +echo "It is better to use the absolute path." +echo "==============================================================================================================" +set -e +DATA_PATH=$1 +DEVICE_ID=$2 +export DATA_PATH=${DATA_PATH} + +EXEC_PATH=$(pwd) + +echo "$EXEC_PATH" + +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +cd ../ +export DEVICE_ID=$2 +export RANK_ID=$2 +env > env0.log +python3 eval.py --data_url $1 --isModelArts False --device_id $2 --ckpt_path $3> eval.log 2>&1 + +if [ $? 
-eq 0 ];then + echo "testing success" +else + echo "testing failed" + exit 2 +fi +echo "finish" +cd ../ diff --git a/model_zoo/research/cv/Glore_resnet200/scripts/run_standalone_train.sh b/model_zoo/research/cv/Glore_resnet200/scripts/run_standalone_train.sh new file mode 100644 index 00000000000..a60db989650 --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/scripts/run_standalone_train.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_standalone_train.sh DATA_PATH DEVICE_ID" +echo "For example: bash run_standalone_train.sh /path/dataset 0" +echo "It is better to use the absolute path." +echo "==============================================================================================================" +set -e +DATA_PATH=$1 +DEVICE_ID=$2 +export DATA_PATH=${DATA_PATH} + +EXEC_PATH=$(pwd) + +echo "$EXEC_PATH" + +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +cd ../ +export DEVICE_ID=$2 +export RANK_ID=$2 +env > env0.log +python3 train.py --data_url $1 --isModelArts False --run_distribute False --device_id $2 > train.log 2>&1 + +if [ $? -eq 0 ];then + echo "training success" +else + echo "training failed" + exit 2 +fi +echo "finish" +cd ../ diff --git a/model_zoo/research/cv/Glore_resnet200/src/__init__.py b/model_zoo/research/cv/Glore_resnet200/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/research/cv/Glore_resnet200/src/config.py b/model_zoo/research/cv/Glore_resnet200/src/config.py new file mode 100644 index 00000000000..8cab7dad0e7 --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/src/config.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+network config setting, will be used in train.py and eval.py
+"""
+from easydict import EasyDict
+config = EasyDict({
+    "class_num": 1000,
+    "batch_size": 128,
+    "loss_scale": 1024,
+    "momentum": 0.08,
+    "weight_decay": 0.0002,
+    "epoch_size": 150,
+    "pretrain_epoch_size": 0,
+    "save_checkpoint": True,
+    "save_checkpoint_epochs": 5,
+    "keep_checkpoint_max": 10,
+    "save_checkpoint_path": "./",
+    "warmup_epochs": 0,
+    "lr_decay_mode": "poly",
+    "lr_init": 0.1,
+    "lr_end": 0,
+    "lr_max": 0.4
+})
diff --git a/model_zoo/research/cv/Glore_resnet200/src/dataset.py b/model_zoo/research/cv/Glore_resnet200/src/dataset.py
new file mode 100644
index 00000000000..6f1046c3fdd
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/src/dataset.py
@@ -0,0 +1,125 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Data operations, will be used in train.py and eval.py
+"""
+import os
+import mindspore.common.dtype as mstype
+import mindspore.dataset.engine as de
+import mindspore.dataset.vision.c_transforms as C
+import mindspore.dataset.transforms.c_transforms as C2
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.dataset.vision import Inter
+from src.transform import RandAugment
+
+
+def create_dataset_ImageNet(dataset_path, do_train, use_randaugment=False, repeat_num=1, batch_size=32,
+                            target="Ascend"):
+    """
+    create a train or eval imagenet2012 dataset for glore_resnet200
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        use_randaugment(bool): enable randAugment.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: Ascend
+
+    Returns:
+        dataset
+    """
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+    else:
+        init("nccl")
+        rank_id = get_rank()
+        device_num = get_group_size()
+
+    if device_num == 1:
+        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
+    else:
+        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
+                                   num_shards=device_num, shard_id=rank_id)
+
+    image_size = 224
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+
+    if do_train:
+        if use_randaugment:
+            trans = [
+                C.Decode(),
+                C.RandomResizedCrop(size=(image_size, image_size),
+                                    scale=(0.08, 1.0),
+                                    ratio=(3. / 4., 4. 
/ 3.), + interpolation=Inter.BICUBIC), + C.RandomHorizontalFlip(prob=0.5), + ] + else: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + else: + use_randaugment = False + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans) + ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op) + + # apply batch operations + if use_randaugment: + efficient_rand_augment = RandAugment() + ds = ds.batch(batch_size, + per_batch_map=efficient_rand_augment, + input_columns=['image', 'label'], + num_parallel_workers=2, + drop_remainder=True) + else: + ds = ds.batch(batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + rank_size = int(os.environ.get("RANK_SIZE")) + rank_id = int(os.environ.get("RANK_ID")) + else: + rank_size = 1 + rank_id = 0 + + return rank_size, rank_id diff --git a/model_zoo/research/cv/Glore_resnet200/src/glore_resnet200.py b/model_zoo/research/cv/Glore_resnet200/src/glore_resnet200.py new file mode 100644 index 00000000000..62931a53d9d --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/src/glore_resnet200.py @@ -0,0 +1,336 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""glore_resnet200""" +from collections import OrderedDict +import numpy as np +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.ops import operations as P +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor + + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 3, 3) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight, has_bias=False) + + +def _conv1x1(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 1, 1) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight, has_bias=False) + + +def _conv7x7(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 7, 7) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight, has_bias=False) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.92, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.92, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _fc(in_channel, out_channel): + weight_shape = (out_channel, in_channel) + weight = _weight_variable(weight_shape) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class BN_AC_Conv(nn.Cell): + """ + Basic convolution block. 
+    """
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 kernel=1,
+                 pad=0,
+                 pad_mode='same',
+                 stride=1,
+                 groups=1,
+                 has_bias=False):
+        super(BN_AC_Conv, self).__init__()
+        self.bn = _bn(in_channel)
+        self.relu = nn.ReLU()
+        self.conv = nn.Conv2d(in_channel, out_channel,
+                              pad_mode=pad_mode,
+                              padding=pad,
+                              kernel_size=kernel,
+                              stride=stride,
+                              has_bias=has_bias,
+                              group=groups)
+
+    def construct(self, x):
+        out = self.bn(x)
+        out = self.relu(out)
+        out = self.conv(out)
+        return out
+
+
+class GCN(nn.Cell):
+    """
+    Graph convolution unit (single layer)
+    """
+
+    def __init__(self, num_state, num_mode, bias=False):
+        super(GCN, self).__init__()
+        # self.relu1 = nn.ReLU()
+        self.conv1 = nn.Conv1d(num_mode, num_mode, kernel_size=1)
+        self.relu2 = nn.ReLU()
+        self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, has_bias=bias)
+        self.transpose = ops.Transpose()
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        """construct"""
+        identity = x
+        # (n, num_state, num_node) -> (n, num_node, num_state)
+        #                          -> (n, num_state, num_node)
+        out = self.transpose(x, (0, 2, 1))
+        # out = self.relu1(out)
+        out = self.conv1(out)
+        out = self.transpose(out, (0, 2, 1))
+        out = self.add(out, identity)
+        out = self.relu2(out)
+        out = self.conv2(out)
+        return out
+
+
+class GloreUnit(nn.Cell):
+    """
+    Graph-based Global Reasoning Unit
+    Parameter:
+        'normalize' is not necessary if the input size is fixed
+    Args:
+        num_in: number of input channels
+        num_mid: number of channels of the latent interaction space;
+                 the unit uses num_s = 2 * num_mid state channels and
+                 num_n = num_mid graph nodes
+    """
+
+    def __init__(self, num_in, num_mid,
+                 normalize=False):
+        super(GloreUnit, self).__init__()
+        self.normalize = normalize
+        self.num_s = int(2 * num_mid)  # 512 num_in = 1024
+        self.num_n = int(1 * num_mid)  # 256
+        # reduce dim
+        self.conv_state = nn.SequentialCell([_bn(num_in),
+                                             nn.ReLU(),
+                                             _conv1x1(num_in, self.num_s, stride=1)])
+        # projection map
+        self.conv_proj = nn.SequentialCell([_bn(num_in),
+                                            nn.ReLU(),
+                                            _conv1x1(num_in, self.num_n, stride=1)])
+
+        self.gcn = GCN(num_state=self.num_s, num_mode=self.num_n)
+
+        self.conv_extend = nn.SequentialCell([_bn_last(self.num_s),
+                                              nn.ReLU(),
+                                              _conv1x1(self.num_s, num_in, stride=1)])
+
+        self.reshape = ops.Reshape()
+        self.matmul = ops.BatchMatMul()
+        self.transpose = ops.Transpose()
+        self.add = P.TensorAdd()
+        self.cast = P.Cast()
+
+    def construct(self, x):
+        """construct"""
+        n = x.shape[0]
+        identity = x
+        # (n, num_in, h, w) --> (n, num_state, h, w)
+        #                   --> (n, num_state, h*w)
+        x_conv_state = self.conv_state(x)
+        x_state_reshaped = self.reshape(x_conv_state, (n, self.num_s, -1))
+
+        # (n, num_in, h, w) --> (n, num_node, h, w)
+        #                   --> (n, num_node, h*w)
+        x_conv_proj = self.conv_proj(x)
+        x_proj_reshaped = self.reshape(x_conv_proj, (n, self.num_n, -1))
+
+        # (n, num_in, h, w) --> (n, num_node, h, w)
+        #                   --> (n, num_node, h*w)
+        x_rproj_reshaped = x_proj_reshaped
+
+        # projection: coordinate space -> interaction space
+        # (n, num_state, h*w) x (n, num_node, h*w)T --> (n, num_state, num_node)
+        x_proj_reshaped = self.transpose(x_proj_reshaped, (0, 2, 1))
+
+        # cast to float16 to speed up BatchMatMul, then cast the result back
+        x_state_reshaped_fp16 = self.cast(x_state_reshaped, mstype.float16)
+        x_proj_reshaped_fp16 = self.cast(x_proj_reshaped, mstype.float16)
+        x_n_state_fp16 = self.matmul(x_state_reshaped_fp16, x_proj_reshaped_fp16)
+        x_n_state = self.cast(x_n_state_fp16, mstype.float32)
+
+        if self.normalize:
+            x_n_state = x_n_state * (1. 
/ x_state_reshaped.shape[2]) + + # reasoning: (n, num_state, num_node) -> (n, num_state, num_node) + x_n_rel = self.gcn(x_n_state) + + # reverse projection: interaction space -> coordinate space + # (n, num_state, num_node) x (n, num_node, h*w) --> (n, num_state, h*w) + x_n_rel_fp16 = self.cast(x_n_rel, mstype.float16) + x_rproj_reshaped_fp16 = self.cast(x_rproj_reshaped, mstype.float16) + x_state_reshaped_fp16 = self.matmul(x_n_rel_fp16, x_rproj_reshaped_fp16) + x_state_reshaped = self.cast(x_state_reshaped_fp16, mstype.float32) + + # (n, num_state, h*w) --> (n, num_state, h, w) + x_state = self.reshape(x_state_reshaped, (n, self.num_s, identity.shape[2], identity.shape[3])) + + # (n, num_state, h, w) -> (n, num_in, h, w) + x_conv_extend = self.conv_extend(x_state) + out = self.add(x_conv_extend, identity) + return out + + +class Residual_Unit(nn.Cell): + """ + Residual unit used in Resnet + """ + def __init__(self, + in_channel, + mid_channel, + out_channel, + groups=1, + stride=1, + first_block=False): + super(Residual_Unit, self).__init__() + self.first_block = first_block + self.BN_AC_Conv1 = BN_AC_Conv(in_channel, mid_channel, kernel=1, pad=0) + self.BN_AC_Conv2 = BN_AC_Conv(mid_channel, mid_channel, kernel=3, pad_mode='pad', pad=1, stride=stride, + groups=groups) + self.BN_AC_Conv3 = BN_AC_Conv(mid_channel, out_channel, kernel=1, pad=0) + if self.first_block: + self.BN_AC_Conv_w = BN_AC_Conv(in_channel, out_channel, kernel=1, pad=0, stride=stride) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + out = self.BN_AC_Conv1(x) + out = self.BN_AC_Conv2(out) + out = self.BN_AC_Conv3(out) + if self.first_block: + identity = self.BN_AC_Conv_w(identity) + + out = self.add(out, identity) + return out + + +class ResNet(nn.Cell): + """ + Resnet architecture + """ + def __init__(self, + layer_nums, + num_classes, + use_glore=False): + super(ResNet, self).__init__() + self.layer1 = nn.SequentialCell(OrderedDict([ + ('conv', _conv7x7(3, 64, stride=2)), + ('bn', _bn(64),), + ('relu', nn.ReLU(),), + ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")) + ])) + + num_in = [64, 256, 512, 1024] + num_mid = [64, 128, 256, 512] + num_out = [256, 512, 1024, 2048] + self.layer2 = nn.SequentialCell(OrderedDict([ + ("Residual_Unit{}".format(i), Residual_Unit(in_channel=(num_in[0] if i == 1 else num_out[0]), + mid_channel=num_mid[0], + out_channel=num_out[0], + stride=1, + first_block=(i == 1))) for i in range(1, layer_nums[0] + 1) + ])) + + blocks_layer3 = [] + for i in range(1, layer_nums[1] + 1): + blocks_layer3.append( + ("Residual_Unit{}".format(i), Residual_Unit(in_channel=(num_in[1] if i == 1 else num_out[1]), + mid_channel=num_mid[1], + out_channel=num_out[1], + stride=(2 if i == 1 else 1), + first_block=(i == 1)))) + if use_glore and i in [12, 18]: + blocks_layer3.append(("Residual_Unit{}_GloreUnit".format(i), GloreUnit(num_out[1], num_mid[1]))) + self.layer3 = nn.SequentialCell(OrderedDict(blocks_layer3)) + + blocks_layer4 = [] + for i in range(1, layer_nums[2] + 1): + blocks_layer4.append( + ("Residual_Unit{}".format(i), Residual_Unit(in_channel=(num_in[2] if i == 1 else num_out[2]), + mid_channel=num_mid[2], + out_channel=num_out[2], + stride=(2 if i == 1 else 1), + first_block=(i == 1)))) + if use_glore and i in [16, 24, 32]: + blocks_layer4.append(("Residual_Unit{}_GloreUnit".format(i), GloreUnit(num_out[2], num_mid[2]))) + self.layer4 = nn.SequentialCell(OrderedDict(blocks_layer4)) + + self.layer5 = nn.SequentialCell(OrderedDict([ + 
("Residual_Unit{}".format(i), Residual_Unit(in_channel=(num_in[3] if i == 1 else num_out[3]), + mid_channel=num_mid[3], + out_channel=num_out[3], + stride=(2 if i == 1 else 1), + first_block=(i == 1))) for i in range(1, layer_nums[3] + 1) + ])) + + self.tail = nn.SequentialCell(OrderedDict([ + ('bn', _bn(num_out[3])), + ('relu', nn.ReLU()) + ])) + + # self.globalpool = nn.AvgPool2d(kernel_size=7, stride=1, pad_mode='same') + self.mean = ops.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.classifier = _fc(num_out[3], num_classes) + self.print = ops.Print() + + def construct(self, x): + """construct""" + c1 = self.layer1(x) + c2 = self.layer2(c1) + c3 = self.layer3(c2) + c4 = self.layer4(c3) + c5 = self.layer5(c4) + + out = self.tail(c5) + # out = self.globalpool(out) + out = self.mean(out, (2, 3)) + out = self.flatten(out) + out = self.classifier(out) + return out + + +def glore_resnet200(class_num=1000, use_glore=True): + return ResNet(layer_nums=[3, 24, 36, 3], + num_classes=class_num, + use_glore=use_glore) diff --git a/model_zoo/research/cv/Glore_resnet200/src/loss.py b/model_zoo/research/cv/Glore_resnet200/src/loss.py new file mode 100644 index 00000000000..ff9c650cb6d --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/src/loss.py @@ -0,0 +1,53 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""define loss function for network""" +from mindspore import dtype as mstype +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.ops as ops + + +class SoftmaxCrossEntropyExpand(nn.Cell): # pylint: disable=missing-docstring + def __init__(self, sparse=False): + super(SoftmaxCrossEntropyExpand, self).__init__() + self.exp = ops.Exp() + self.sum = ops.ReduceSum(keep_dims=True) + self.onehot = ops.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.div = ops.RealDiv() + self.log = ops.Log() + self.sum_cross_entropy = ops.ReduceSum(keep_dims=False) + self.mul = ops.Mul() + self.mul2 = ops.Mul() + self.mean = ops.ReduceMean(keep_dims=False) + self.sparse = sparse + self.max = ops.ReduceMax(keep_dims=True) + self.sub = ops.Sub() + self.eps = Tensor(1e-24, mstype.float32) + + def construct(self, logit, label): # pylint: disable=missing-docstring + logit_max = self.max(logit, -1) + exp = self.exp(self.sub(logit, logit_max)) + exp_sum = self.sum(exp, -1) + softmax_result = self.div(exp, exp_sum) + if self.sparse: + label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value) + + softmax_result_log = self.log(softmax_result + self.eps) + loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1) + loss = self.mul2(ops.scalar_to_array(-1.0), loss) + loss = self.mean(loss, -1) + return loss diff --git a/model_zoo/research/cv/Glore_resnet200/src/lr_generator.py b/model_zoo/research/cv/Glore_resnet200/src/lr_generator.py new file mode 100644 index 00000000000..ee35b352640 --- /dev/null +++ b/model_zoo/research/cv/Glore_resnet200/src/lr_generator.py @@ -0,0 +1,128 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""learning rate generator""" +import math +import numpy as np + + +def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(int): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + lr_decay_mode(string): learning rate decay mode, including steps, poly or default + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + if lr_decay_mode == 'steps': + decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] + for i in range(total_steps): + if i < decay_epoch_index[0]: + lr = lr_max + elif i < decay_epoch_index[1]: + lr = lr_max * 0.1 + elif i < decay_epoch_index[2]: + lr = lr_max * 0.01 + else: + lr = lr_max * 0.001 + lr_each_step.append(lr) + elif lr_decay_mode == 'poly': + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr = float(lr_max) * base * base + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + elif lr_decay_mode == 'cosine': + decay_steps = total_steps - warmup_steps + for i in range(total_steps): + if i < warmup_steps: + lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) + lr = float(lr_init) + lr_inc * (i + 1) + else: + linear_decay = (total_steps - i) / decay_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) + decayed = linear_decay * cosine_decay + 0.00001 + lr = lr_max * decayed + lr_each_step.append(lr) + else: + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) + lr_each_step.append(lr) + + lr_each_step = np.array(lr_each_step).astype(np.float32) + + return lr_each_step + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0): + """ + generate learning rate array with cosine + + Args: + lr(float): base learning rate + steps_per_epoch(int): steps size of one epoch + warmup_epochs(int): number of warmup epochs + max_epoch(int): total epochs of training + global_step(int): the current start index of lr array + Returns: + np.array, learning rate array + """ + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + decay_steps = total_steps - warmup_steps + + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + linear_decay = (total_steps - i) / decay_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) + decayed = linear_decay * cosine_decay + 0.00001 + lr = base_lr * decayed + 
lr_each_step.append(lr)
+
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    learning_rate = lr_each_step[global_step:]
+    return learning_rate
diff --git a/model_zoo/research/cv/Glore_resnet200/src/transform.py b/model_zoo/research/cv/Glore_resnet200/src/transform.py
new file mode 100644
index 00000000000..cba6ea73a65
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/src/transform.py
@@ -0,0 +1,51 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+random augment class
+"""
+import numpy as np
+import mindspore.dataset.vision.py_transforms as P
+from src import transform_utils

+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+class RandAugment:
+    """
+    random augment
+    """
+    # config_str: RandAugment spec string, e.g. "rand-m9-mstd0.5"
+    # hparams: dict of extra hyper-parameters
+    def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
+        hparams = hparams if hparams is not None else {}
+        self.config_str = config_str
+        self.hparams = hparams
+
+    def __call__(self, imgs, labels, batchInfo):
+        # assert the imgs object are pil_images
+        ret_imgs = []
+        ret_labels = []
+        py_to_pil_op = P.ToPIL()
+        to_tensor = P.ToTensor()
+        normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+        rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams)
+        for i, image in enumerate(imgs):
+            img_pil = py_to_pil_op(image)
+            img_pil = rand_augment_ops(img_pil)
+            img_array = to_tensor(img_pil)
+            img_array = normalize_op(img_array)
+            ret_imgs.append(img_array)
+            ret_labels.append(labels[i])
+        return np.array(ret_imgs), np.array(ret_labels)
diff --git a/model_zoo/research/cv/Glore_resnet200/src/transform_utils.py b/model_zoo/research/cv/Glore_resnet200/src/transform_utils.py
new file mode 100644
index 00000000000..c8a1301fca1
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/src/transform_utils.py
@@ -0,0 +1,594 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+augment operation
+"""
+import math
+import random
+import re
+import numpy as np
+import PIL
+from PIL import Image, ImageEnhance, ImageFilter, ImageOps  # ImageFilter is needed by blur()
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+_FILL = (128, 128, 128)
+_MAX_LEVEL = 10.
+_HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL) +_RAND_TRANSFORMS = [ + 'Distort', + 'Zoom', + 'Blur', + 'Skew', + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, + 'Distort': 0, + 'Zoom': 0, + 'Blur': 0, + 'Skew': 0, +} + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + +# define all kinds of functions + + +def _randomly_negate(v): + return -v if random.random() > 0.5 else v + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + """ + rotate operation + """ + kwargs_new = kwargs + kwargs_new.pop('resample') + kwargs_new['resample'] = Image.BICUBIC + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs_new) + if _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs_new) + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + """ + add solarize operation + """ + lut = [] + for i in range(256): 
+ if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, _hparams): + # range [-0.45, 0.45] + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return (level,) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _posterize_research_level_to_arg(level, _hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image' + return (4 - int((level / _MAX_LEVEL) * 4),) + + +def _posterize_tpu_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + return (int((level / _MAX_LEVEL) * 4),) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +def _distort_level_to_arg(level, _hparams): + return (int((level / _MAX_LEVEL) * 10 + 10),) + + +def _zoom_level_to_arg(level, _hparams): + return ((level / _MAX_LEVEL) * 0.4,) + + +def _blur_level_to_arg(level, _hparams): + level = (level / _MAX_LEVEL) * 0.5 + level = _randomly_negate(level) + return (level,) + + +def _skew_level_to_arg(level, _hparams): + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def distort(img, v, **__): + """ + distort operation + """ + w, h = img.size + horizontal_tiles = int(0.1 * v) + vertical_tiles = int(0.1 * v) + + width_of_square = int(math.floor(w / float(horizontal_tiles))) + height_of_square = int(math.floor(h / float(vertical_tiles))) + width_of_last_square = w - (width_of_square * (horizontal_tiles - 1)) + height_of_last_square = h - (height_of_square * (vertical_tiles - 1)) + dimensions = [] + + for vertical_tile in range(vertical_tiles): + for horizontal_tile in range(horizontal_tiles): + if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_last_square + 
(horizontal_tile * width_of_square), + height_of_last_square + (height_of_square * vertical_tile)]) + elif vertical_tile == (vertical_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_square + (horizontal_tile * width_of_square), + height_of_last_square + (height_of_square * vertical_tile)]) + elif horizontal_tile == (horizontal_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_last_square + (horizontal_tile * width_of_square), + height_of_square + (height_of_square * vertical_tile)]) + else: + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_square + (horizontal_tile * width_of_square), + height_of_square + (height_of_square * vertical_tile)]) + last_column = [] + for i in range(vertical_tiles): + last_column.append((horizontal_tiles - 1) + horizontal_tiles * i) + + last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles) + + polygons = [] + for x1, y1, x2, y2 in dimensions: + polygons.append([x1, y1, x1, y2, x2, y2, x2, y1]) + + polygon_indices = [] + for i in range((vertical_tiles * horizontal_tiles) - 1): + if i not in last_row and i not in last_column: + polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles]) + + for a, b, c, d in polygon_indices: + dx = v + dy = v + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a] + polygons[a] = [x1, y1, + x2, y2, + x3 + dx, y3 + dy, + x4, y4] + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b] + polygons[b] = [x1, y1, + x2 + dx, y2 + dy, + x3, y3, + x4, y4] + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c] + polygons[c] = [x1, y1, + x2, y2, + x3, y3, + x4 + dx, y4 + dy] + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d] + polygons[d] = [x1 + dx, y1 + dy, + x2, y2, + x3, y3, + x4, y4] + + generated_mesh = [] + for idx, i in enumerate(dimensions): + generated_mesh.append([dimensions[idx], polygons[idx]]) + return img.transform(img.size, PIL.Image.MESH, generated_mesh, resample=PIL.Image.BICUBIC) + + +def zoom(img, v, **__): + #assert 0.1 <= v <= 2 + w, h = img.size + image_zoomed = img.resize((int(round(img.size[0] * v)), + int(round(img.size[1] * v))), + resample=PIL.Image.BICUBIC) + w_zoomed, h_zoomed = image_zoomed.size + + return image_zoomed.crop((math.floor((float(w_zoomed) / 2) - (float(w) / 2)), + math.floor((float(h_zoomed) / 2) - (float(h) / 2)), + math.floor((float(w_zoomed) / 2) + (float(w) / 2)), + math.floor((float(h_zoomed) / 2) + (float(h) / 2)))) + + +def erase(img, v, **__): + """ + distort operation + """ + #assert 0.1<= v <= 1 + w, h = img.size + w_occlusion = int(w * v) + h_occlusion = int(h * v) + if len(img.getbands()) == 1: + rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion) * 255)) + else: + rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion, len(img.getbands())) * 255)) + + random_position_x = random.randint(0, w - w_occlusion) + random_position_y = random.randint(0, h - h_occlusion) + img.paste(rectangle, (random_position_x, random_position_y)) + return img + + +def skew(img, v, **__): + """ + skew operation + """ + #assert -1 <= v <= 1 + w, h = img.size + x1 = 0 + x2 = h + y1 = 0 + y2 = w + original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)] + max_skew_amount = max(w, h) + max_skew_amount = int(math.ceil(max_skew_amount * v)) + skew_amount = max_skew_amount + new_plane = [(y1 - skew_amount, x1), # Top Left + 
+
+
+def skew(img, v, **__):
+    """
+    skew operation: perspective-warp the image by shifting the corner plane
+    """
+    # assert -1 <= v <= 1
+    w, h = img.size
+    x1 = 0
+    x2 = h
+    y1 = 0
+    y2 = w
+    original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]
+    max_skew_amount = max(w, h)
+    max_skew_amount = int(math.ceil(max_skew_amount * v))
+    skew_amount = max_skew_amount
+    new_plane = [(y1 - skew_amount, x1),  # Top Left
+                 (y2, x1 - skew_amount),  # Top Right
+                 (y2 + skew_amount, x2),  # Bottom Right
+                 (y1, x2 + skew_amount)]  # Bottom Left
+    # solve (least squares) for the 8 coefficients of PIL's PERSPECTIVE transform
+    matrix = []
+    for p1, p2 in zip(new_plane, original_plane):
+        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
+        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
+
+    A = np.array(matrix, dtype=float)  # `np.float` is a removed alias; plain float keeps float64
+    B = np.array(original_plane).reshape(8)
+    perspective_skew_coefficients_matrix = np.dot(np.linalg.pinv(A), B)
+    perspective_skew_coefficients_matrix = np.array(perspective_skew_coefficients_matrix).reshape(8)
+
+    return img.transform(img.size, PIL.Image.PERSPECTIVE, perspective_skew_coefficients_matrix,
+                         resample=PIL.Image.BICUBIC)
+
+
+def blur(img, v, **__):
+    """
+    gaussian blur operation
+    """
+    # assert -3 <= v <= 3
+    return img.filter(PIL.ImageFilter.GaussianBlur(v))
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _RAND_TRANSFORMS
+    return [AutoAugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+    transforms = transforms or _RAND_TRANSFORMS
+    assert weight_idx == 0  # only one set of weights currently
+    rand_weights = _RAND_CHOICE_WEIGHTS_0
+    probs = [rand_weights[k] for k in transforms]
+    probs /= np.sum(probs)
+    return probs
+
+
+def rand_augment_transform(config_str, hparams):
+    """
+    build a RandAugment transform from a config string such as 'rand-m9-n2-mstd0.5'
+    """
+    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
+    num_layers = 2  # default to 2 ops per image
+    weight_idx = None  # default to no probability weights for op choice
+    config = config_str.split('-')
+    assert config[0] == 'rand'
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'n':
+            num_layers = int(val)
+        elif key == 'w':
+            weight_idx = int(val)
+        else:
+            assert False, 'Unknown RandAugment config section'
+    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
+    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+
+    final_result = RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+    return final_result
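+
+
+# The config grammar parsed above, by example (illustrative; the hparams values
+# are placeholders): 'rand-m9-n2' selects magnitude 9 with 2 ops per image, and
+# appending '-mstd0.5' lets each op draw its magnitude from a Gaussian (std 0.5).
+#
+#     ra = rand_augment_transform('rand-m9-n2-mstd0.5', {'img_mean': (124, 116, 104)})
+#     augmented = ra(pil_img)   # pil_img: any PIL.Image; applies 2 random ops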
+
+
+LEVEL_TO_ARG = {
+    'Distort': _distort_level_to_arg,
+    'Zoom': _zoom_level_to_arg,
+    'Blur': _blur_level_to_arg,
+    'Skew': _skew_level_to_arg,
+    'AutoContrast': None,
+    'Equalize': None,
+    'Invert': None,
+    'Rotate': _rotate_level_to_arg,
+    'PosterizeOriginal': _posterize_original_level_to_arg,
+    'PosterizeResearch': _posterize_research_level_to_arg,
+    'PosterizeTpu': _posterize_tpu_level_to_arg,
+    'Solarize': _solarize_level_to_arg,
+    'SolarizeAdd': _solarize_add_level_to_arg,
+    'Color': _enhance_level_to_arg,
+    'Contrast': _enhance_level_to_arg,
+    'Brightness': _enhance_level_to_arg,
+    'Sharpness': _enhance_level_to_arg,
+    'ShearX': _shear_level_to_arg,
+    'ShearY': _shear_level_to_arg,
+    'TranslateX': _translate_abs_level_to_arg,
+    'TranslateY': _translate_abs_level_to_arg,
+    'TranslateXRel': _translate_rel_level_to_arg,
+    'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+NAME_TO_OP = {
+    'Distort': distort,
+    'Zoom': zoom,
+    'Blur': blur,
+    'Skew': skew,
+    'AutoContrast': auto_contrast,
+    'Equalize': equalize,
+    'Invert': invert,
+    'Rotate': rotate,
+    'PosterizeOriginal': posterize,
+    'PosterizeResearch': posterize,
+    'PosterizeTpu': posterize,
+    'Solarize': solarize,
+    'SolarizeAdd': solarize_add,
+    'Color': color,
+    'Contrast': contrast,
+    'Brightness': brightness,
+    'Sharpness': sharpness,
+    'ShearX': shear_x,
+    'ShearY': shear_y,
+    'TranslateX': translate_x_abs,
+    'TranslateY': translate_y_abs,
+    'TranslateXRel': translate_x_rel,
+    'TranslateYRel': translate_y_rel,
+}
+
+
+class AutoAugmentOp:
+    """
+    apply a single named augmentation op with probability `prob` at the given magnitude
+    """
+    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+        hparams = hparams or _HPARAMS_DEFAULT
+        self.aug_fn = NAME_TO_OP[name]
+        self.level_fn = LEVEL_TO_ARG[name]
+        self.prob = prob
+        self.magnitude = magnitude
+        self.hparams = hparams.copy()
+        self.kwargs = dict(
+            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+        )
+        self.magnitude_std = self.hparams.get('magnitude_std', 0)
+
+    def __call__(self, img):
+        if random.random() > self.prob:
+            return img
+        magnitude = self.magnitude
+        if self.magnitude_std and self.magnitude_std > 0:
+            magnitude = random.gauss(magnitude, self.magnitude_std)
+        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+        return self.aug_fn(img, *level_args, **self.kwargs)
+
+
+class RandAugment:
+    """
+    apply `num_layers` randomly chosen ops from `ops` to an image
+    """
+    def __init__(self, ops, num_layers=2, choice_weights=None):
+        self.ops = ops
+        self.num_layers = num_layers
+        self.choice_weights = choice_weights
+
+    def __call__(self, img):
+        ops = np.random.choice(
+            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+        for op in ops:
+            img = op(img)
+        return img
diff --git a/model_zoo/research/cv/Glore_resnet200/train.py b/model_zoo/research/cv/Glore_resnet200/train.py
new file mode 100644
index 00000000000..1918a9f2ae6
--- /dev/null
+++ b/model_zoo/research/cv/Glore_resnet200/train.py
@@ -0,0 +1,179 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+#################train glore_resnet200 on Imagenet2012########################
+python train.py
+"""
+import os
+import random
+import argparse
+import ast
+import numpy as np
+from mindspore import Tensor
+from mindspore import context
+from mindspore import dataset as de
+from mindspore.train.model import Model, ParallelMode
+from mindspore.communication import management as MultiAscend
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.communication.management import init
+import mindspore.nn as nn
+import mindspore.common.initializer as weight_init
+from src.lr_generator import get_lr
+from src.glore_resnet200 import glore_resnet200
+from src.dataset import create_dataset_ImageNet as get_dataset
+from src.config import config
+from src.loss import SoftmaxCrossEntropyExpand
+
+
+parser = argparse.ArgumentParser(description='Image classification with glore_resnet200')
+parser.add_argument('--use_glore', type=ast.literal_eval, default=True, help='Enable GloreUnit')
+parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
+parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
+parser.add_argument('--train_url', type=str)
+parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
+parser.add_argument('--device_id', type=int, default=0)
+parser.add_argument('--pre_trained', type=ast.literal_eval, default=False)
+parser.add_argument('--pre_ckpt_path', type=str, default='')
+parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
+parser.add_argument('--isModelArts', type=ast.literal_eval, default=True)
+args_opt = parser.parse_args()
+
+if args_opt.isModelArts:
+    import moxing as mox
+
+random.seed(1)
+np.random.seed(1)
+de.config.set_seed(1)
+
+if __name__ == '__main__':
+    target = args_opt.device_target
+    ckpt_save_dir = config.save_checkpoint_path
+    # init context
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    if args_opt.run_distribute:
+        if target == "Ascend":
+            device_id = int(os.getenv('DEVICE_ID'))
+            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True,
+                                              auto_parallel_search_mode="recursive_programming")
+            init()
+    else:
+        if target == "Ascend":
+            context.set_context(device_id=args_opt.device_id)
+
+    train_dataset_path = args_opt.data_url
+    if args_opt.isModelArts:
+        # download dataset from obs to cache
+        mox.file.copy_parallel(src_url=args_opt.data_url, dst_url='/cache/dataset/device_' + os.getenv('DEVICE_ID'))
+        train_dataset_path = '/cache/dataset/device_' + os.getenv('DEVICE_ID')
+
+    # create dataset
+    dataset = get_dataset(dataset_path=train_dataset_path, do_train=True, use_randaugment=True, repeat_num=1,
+                          batch_size=config.batch_size, target=target)
+    step_size = dataset.get_dataset_size()
+
+    # define net
+    net = glore_resnet200(class_num=config.class_num, use_glore=args_opt.use_glore)
+
+    # init weight
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_ckpt_path)
+        load_param_into_net(net, param_dict)
+    else:
+        for _, cell in net.cells_and_names():
+            if isinstance(cell, nn.Conv2d):
+                cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
+                                                                    cell.weight.shape,
+                                                                    cell.weight.dtype)
+            if isinstance(cell, nn.Dense):
+                cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
+                                                                    cell.weight.shape,
+                                                                    cell.weight.dtype)
+
+    # init lr
+    lr = get_lr(lr_init=config.lr_init,
+                lr_end=config.lr_end,
+                lr_max=config.lr_max,
+                warmup_epochs=config.warmup_epochs,
+                total_epochs=config.epoch_size,
+                steps_per_epoch=step_size,
+                lr_decay_mode=config.lr_decay_mode)
+    lr = Tensor(lr)
+
+    # define opt: apply weight decay to conv/dense weights only
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
+    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
+                    {'params': no_decayed_params},
+                    {'order_params': net.trainable_params()}]
+    net_opt = nn.SGD(group_params, learning_rate=lr, momentum=config.momentum, weight_decay=config.weight_decay,
+                     loss_scale=config.loss_scale, nesterov=True)
+
+    # define loss, model
+    loss = SoftmaxCrossEntropyExpand(sparse=True)
+    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+    model = Model(net, loss_fn=loss, optimizer=net_opt, loss_scale_manager=loss_scale)
+
+    # define callbacks
+    time_cb = TimeMonitor(data_size=step_size)
+    loss_cb = LossMonitor()
+    cb = [time_cb, loss_cb]
+    rank_size = os.getenv("RANK_SIZE")
+    if config.save_checkpoint:
+        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
+                                     keep_checkpoint_max=config.keep_checkpoint_max)
+        if args_opt.isModelArts:
+            save_checkpoint_path = '/cache/train_output/checkpoint'
+        else:
+            save_checkpoint_path = os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID")))
+        # save checkpoints on the single device, or once per 8-device server in distributed runs
+        if rank_size is None or int(rank_size) == 1 or MultiAscend.get_rank() % 8 == 0:
+            ckpt_cb = ModelCheckpoint(prefix='glore_resnet200',
+                                      directory=save_checkpoint_path,
+                                      config=config_ck)
+            cb += [ckpt_cb]
+
+    model.train(config.epoch_size - config.pretrain_epoch_size, dataset,
+                callbacks=cb, dataset_sink_mode=True)
+    if args_opt.isModelArts:
+        mox.file.copy_parallel(src_url='/cache/train_output/checkpoint', dst_url=args_opt.train_url)
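+
+# Quick smoke-test sketch for the network used above (illustrative only; it
+# assumes a configured device context and that 224x224 inputs match the
+# pipeline in src/dataset.py, with class_num=1000 for ImageNet2012):
+#
+#     import numpy as np
+#     from mindspore import Tensor, context
+#     from src.glore_resnet200 import glore_resnet200
+#
+#     context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
+#     net = glore_resnet200(class_num=1000, use_glore=True)
+#     logits = net(Tensor(np.random.rand(1, 3, 224, 224).astype(np.float32)))
+#     print(logits.shape)   # expected: (1, 1000)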