fix pretrained model in psenet network

This commit is contained in:
anzhengqi 2021-08-03 17:39:39 +08:00
parent 11cf74e6e8
commit 8d7f979bf4
7 changed files with 558 additions and 165 deletions

View File

@ -2,6 +2,7 @@
- [PSENet Description](#PSENet-description)
- [Dataset](#dataset)
- [Pretrained Model](#Pretrained-model)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
@ -15,6 +16,7 @@
- [Distributed GPU Training](#distributed-gpu-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Result](#result)
- [Inference Process](#inference-process)
- [Export MindIR](#export-mindir)
- [Infer on Ascend310](#infer-on-ascend310)
@ -48,6 +50,19 @@ Dataset used: [ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalizatio
A training set of 1000 images containing about 4500 readable words
A testing set containing about 2000 readable words
Unzip the dataset files; they do not need to be converted to MindRecord format.
# [Pretrained Model](#contents)
Download the PyTorch pretrained model: [resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)
Convert the PyTorch checkpoint to a MindSpore checkpoint:
```shell
cd src
python psenet_model_torch2mindspore.py --torch_file=/path_to_model/resnet50-19c8e357.pth --output_path=../
```
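After the script finishes, a `pretrained_model.ckpt` file is written to the directory above `src` and is the file passed to training as the pretrained model. As a quick sanity check, the converted checkpoint can be read back with MindSpore's serialization API; the snippet below is a minimal sketch and assumes it is run from the directory that contains `pretrained_model.ckpt`:

```python
# Minimal sanity check for the converted checkpoint (the path is an assumption).
from mindspore.train.serialization import load_checkpoint

param_dict = load_checkpoint("pretrained_model.ckpt")
print("parameters in checkpoint:", len(param_dict))
# Print a few parameter names; they should follow the MindSpore naming used by the
# conversion script, e.g. feature_extractor.layer1.0.bn1.gamma.
for name in list(param_dict)[:5]:
    print(name)
```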
# [Environment Requirements](#contents)
- Hardware: Ascend or GPU
@ -61,34 +76,100 @@ A testing set containing about 2000 readable words
- install [pybind11](https://github.com/pybind/pybind11)
- install [OpenCV 3.4](https://docs.opencv.org/3.4.9/)
```shell
# install pybind11
pip install pybind11
# install opencv3.4.9
wget https://github.com/opencv/opencv/archive/3.4.9.zip
unzip 3.4.9.zip
cd opencv-3.4.9
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release -D CMAKE_INSTALL_PREFIX=/usr/local -D WITH_WEBP=OFF ..
make -j4 # -j sets the number of parallel build jobs; adjust it to your machine
make install
# export environment variables
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/local/include
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64
```
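Before building the `pse` extension, it can help to confirm that the Python-side dependencies are importable. This is only a hedged sketch: the `cv2` module comes from the `opencv-python` package listed in requirements.txt and may report a different version than the C++ build above.

```python
# Quick import check for the Python dependencies used alongside the C++ build.
import cv2        # from the opencv-python package in requirements.txt
import pybind11   # header-only package installed via pip above

print("python opencv version:", cv2.__version__)
print("pybind11 include dir:", pybind11.get_include())
```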
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```python
```shell
# run distributed training example
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
# download the opencv library
download pybind11, opencv3.4
# install pybind11 and opencv3.4
setup pybind11 (install the library with pip)
setup opencv3.4 (build and install the library from source)
# enter the pse directory and run make to build the extension
cd ./src/ETSNET/pse/;make
# run test.py
python test.py --ckpt pretrained_model.ckpt --TEST_ROOT_DIR [test root path]
# download the evaluation method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)
# click the "My Methods" button, then download the Evaluation Scripts
# see the Evaluation Process section for details
download script.py
# run the evaluation example
bash scripts/run_eval_ascend.sh
```
- running on ModelArts
- If you want to train the model on ModelArts, you can refer to the ModelArts [official guidance document](https://support.huaweicloud.com/modelarts/):
```python
# Example of using distributed training on ModelArts:
# Dataset layout
# ├── ICDAR2015 # dir
# ├── train # train dir
# ├── ic15 # train_dataset dir
# ├── ch4_training_images
# ├── ch4_training_localization_transcription_gt
# ├── train_predtrained # predtrained dir
# ├── eval # eval dir
# ├── ic15 # eval dataset dir
# ├── ch4_test_images
# ├── challenge4_Test_Task1_GT
# ├── checkpoint # ckpt files dir
# (1) Choose either a (modify the yaml file parameters) or b (create a ModelArts training job and modify the parameters there).
#     a. set "enable_modelarts=True"
#        set "run_distribute=True"
#        set "TRAIN_MODEL_SAVE_PATH=/cache/train/outputs_imagenet/"
#        set "TRAIN_ROOT_DIR=/cache/data/ic15/"
#        set "pre_trained=/cache/data/train_predtrained/pred file name"; if there are no pretrained weights, set pre_trained=""
#     b. add "enable_modelarts=True" on the ModelArts interface
#        set the parameters required by method a on the ModelArts interface
#        Note: path parameters do not need to be quoted
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the ModelArts interface: "/path/psenet".
# (4) Set the startup file on the ModelArts interface: "train.py".
# (5) Set the data path on the ModelArts interface: ".../ICDAR2015/train" (select the ICDAR2015/train folder path),
#     as well as the model output path "Output file path" and the log path "Job log path".
# (6) Start training the model.
# Example of using model inference on ModelArts
# (1) Place the trained model in the corresponding location in the bucket.
# (2) Choose either a or b.
#     a. set "enable_modelarts=True"
#        set "TEST_ROOT_DIR=/cache/data/ic15/"
#        set "ckpt=/cache/data/checkpoint/ckpt file"
#     b. add "enable_modelarts=True" on the ModelArts interface
#        set the parameters required by method a on the ModelArts interface
#        Note: path parameters do not need to be quoted
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (4) Set the code path on the ModelArts interface: "/path/psenet".
# (5) Set the startup file on the ModelArts interface: "eval.py".
# (6) Set the data path on the ModelArts interface: ".../ICDAR2015/eval" (select the ICDAR2015/eval folder path),
#     as well as the model output path "Output file path" and the log path "Job log path".
# (7) Start model inference.
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
@ -156,7 +237,7 @@ Major parameters in default_config.yaml are:
Please follow the instructions in the link below: <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
bash scripts/run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH] [TRAIN_ROOT_DIR]
```
The rank_table_file specified by RANK_TABLE_FILE is needed when you run a distributed task. You can generate it with the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
@ -195,66 +276,27 @@ time: 2021-07-24 04:01:07, epoch: 90, step: 31, loss is 0.58495
### run test code
```test
```shell
python test.py --ckpt [CKPT_PATH] --TEST_ROOT_DIR [TEST_DATA_DIR]
# click [Here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization) to download evaluation scripts
# choose My Methods -> Offline evaluation -> Evaluation Scripts
# download data and put it in /path_to_data
mkdir eval_ic15
ln -s /path_to_data/script_test_ch4_t1_e1-1577983151.zip eval_ic15/script_test_ch4_t1_e1-1577983151.zip
cd eval_ic15
unzip script_test_ch4_t1_e1-1577983151.zip
cd ..
sh ./scripts/run_eval_ascend.sh
python test.py --ckpt [CKPT PATH] --TEST_ROOT_DIR [TEST DATA DIR]
```
### [Result](#contents)
Calculated!{"precision": 0.8147966668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}
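The summary line printed by the evaluation script can be parsed back into a Python dict if you want to track metrics programmatically; a small sketch, assuming the `Calculated!{...}` format shown above:

```python
# Parse the metrics out of the evaluation summary line (format taken from the sample above).
import json

line = 'Calculated!{"precision": 0.8147966668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}'
metrics = json.loads(line.split("Calculated!", 1)[1])
print(metrics["precision"], metrics["recall"], metrics["hmean"])
```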
### Eval Script for ICDAR2015
@ -342,8 +384,9 @@ The `res` folder is generated in the upper-level directory. For details about th
| Loss Function | LossCallBack |
| outputs | probability |
| Loss | 0.35 |
| Speed | 1pc: 444 ms/step; 8pcs: 446 ms/step |
| Total time | 1pc: 75.48 h; 8pcs: 7.11 h |
| Parameters | batch_size = 4 |
| Speed | 1pc: 444 ms/step(fps: 9.0); 8pcs: 446 ms/step(fps: 71) |
| Total time | 1pc: 75.48 h; 8pcs: 7.11 h |
| Parameters (M) | 27.36 |
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/psenet> |

View File

@ -5,6 +5,7 @@
- [PSENet Example](#psenet示例)
- [Overview](#概述)
- [Dataset](#数据集)
- [Pretrained Model](#预训练模型)
- [Environment Requirements](#环境要求)
- [Quick Start](#快速入门)
- [Script Description](#脚本说明)
@ -14,9 +15,7 @@
- [Distributed Training](#分布式训练)
- [Evaluation Process](#评估过程)
- [Running the Test Code](#运行测试代码)
- [ICDAR2015 Evaluation Script](#icdar2015评估脚本)
- [Usage](#用法)
- [Result](#结果)
- [Result](#结果)
- [Inference Process](#推理过程)
- [Export MindIR](#导出mindir)
- [Infer on Ascend310](#在ascend310执行推理)
@ -48,6 +47,21 @@
The training set includes 1000 images containing about 4500 readable words.
The test set contains about 2000 readable words.
Unzip the downloaded training and inference data for later use; it does not need to be converted to MindRecord format.
# Pretrained Model
Download the PyTorch pretrained model: [resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)
Convert the PyTorch checkpoint to a MindSpore checkpoint:
```shell
cd src
python psenet_model_torch2mindspore.py --torch_file=/path_to_model/resnet50-19c8e357.pth --output_path=../
```
After the script finishes, a pretrained_model.ckpt file is generated in the directory above src and is used for the subsequent training.
# Environment Requirements
- Hardware (Ascend processor)
@ -62,36 +76,100 @@
- install [pybind11](https://github.com/pybind/pybind11)
- install [OpenCV 3.4](https://docs.opencv.org/3.4.9/)
```shell
# install pybind11 with pip
pip install pybind11
# build and install opencv3.4.9 from source
wget https://github.com/opencv/opencv/archive/3.4.9.zip
unzip 3.4.9.zip
cd opencv-3.4.9
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release -D CMAKE_INSTALL_PREFIX=/usr/local -D WITH_WEBP=OFF ..
make -j4 # -j sets the number of parallel build jobs; adjust it to your machine
make install
# opencv is installed under /usr/local; add that directory to the environment variables
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/local/include
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64
```
# Quick Start
After installing MindSpore via the official website, you can follow the steps below for training and evaluation:
```python
```shell
# distributed training example
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
# the first argument is the rank_table file, the second is the generated pretrained model, the third is the downloaded training dataset
bash scripts/run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH] [TRAIN_ROOT_DIR]
# download the opencv library
download pybind11, opencv3.4
# install pybind11 and opencv3.4
setup pybind11 (install the library with pip)
setup opencv3.4 (build and install the library from source)
# click [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization) to download the evaluation method
# click the "My Methods" button and download the Evaluation Scripts
# enter the pse directory and run make to build the extension
cd ./src/ETSNET/pse/;make clean&&make
# run test.py
python test.py --ckpt pretrained_model.ckpt --TEST_ROOT_DIR [test root path]
python test.py --ckpt [CKPT_PATH] --TEST_ROOT_DIR [TEST_DATA_DIR]
# see the Evaluation Process section for details
download script.py
# run the evaluation example
bash scripts/run_eval_ascend.sh
```
- If you want to train the model on ModelArts, you can refer to the ModelArts [official guidance document](https://support.huaweicloud.com/modelarts/) to start training and inference. The specific steps are as follows:
```ModelArts
# Example of using distributed training on ModelArts:
# Dataset layout
# ├── ICDAR2015 # dir
# ├── train # train dir
# ├── ic15 # train_dataset dir
# ├── ch4_training_images
# ├── ch4_training_localization_transcription_gt
# ├── train_predtrained # predtrained dir
# ├── eval # eval dir
# ├── ic15 # eval dataset dir
# ├── ch4_test_images
# ├── challenge4_Test_Task1_GT
# ├── checkpoint # ckpt files dir
# (1) Choose either a (modify the yaml file parameters) or b (create a ModelArts training job and modify the parameters there).
#     a. set "enable_modelarts=True"
#        set "run_distribute=True"
#        set "TRAIN_MODEL_SAVE_PATH=/cache/train/outputs/"
#        set "TRAIN_ROOT_DIR=/cache/data/ic15/"
#        set "pre_trained=/cache/data/train_predtrained/pred file name"; if there are no pretrained weights, set pre_trained=""
#     b. add "enable_modelarts=True" on the ModelArts interface
#        set the parameters required by method a on the ModelArts interface
#        Note: path parameters do not need to be quoted
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the ModelArts interface: "/path/psenet".
# (4) Set the startup file on the ModelArts interface: "train.py".
# (5) Set the data path on the ModelArts interface: ".../ICDAR2015/train" (select the ICDAR2015/train folder path),
#     as well as the model output path "Output file path" and the log path "Job log path".
# (6) Start training the model.
# Example of using model inference on ModelArts
# (1) Place the trained model in the corresponding location in the bucket.
# (2) Choose either a or b.
#     a. set "enable_modelarts=True"
#        set "TEST_ROOT_DIR=/cache/data/ic15"
#        set "ckpt=/cache/data/checkpoint/ckpt file"
#     b. add "enable_modelarts=True" on the ModelArts interface
#        set the parameters required by method a on the ModelArts interface
#        Note: path parameters do not need to be quoted
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (4) Set the code path on the ModelArts interface: "/path/psenet".
# (5) Set the startup file on the ModelArts interface: "eval.py".
# (6) Set the data path on the ModelArts interface: "../ICDAR2015/eval" (select the ICDAR2015/eval folder path),
#     as well as the model output path "Output file path" and the log path "Job log path".
# (7) Start model inference.
```
## Script Description
## Script and Sample Code
@ -153,7 +231,8 @@ bash scripts/run_eval_ascend.sh
Please follow the instructions in the [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
# the first argument is the rank_table file, the second is the generated pretrained model, the third is the downloaded training dataset
bash scripts/run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH] [TRAIN_ROOT_DIR]
```
The above shell script runs distributed training in the background. You can check the results in the `device[X]/test_*.log` files.
@ -173,81 +252,24 @@ device_1/log: epoch: 2, step: 40, loss is 0.76629
### Run the test code
```test
python test.py --ckpt [CKPT PATH] --TEST_ROOT_DIR [TEST DATA DIR]
```
### ICDAR2015 Evaluation Script
#### Usage
Step 1: Click [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization) to download the evaluation method.
Step 2: Click the "My Methods" button and download the evaluation scripts.
Step 3: It is recommended to symlink the root of the evaluation method to $MINDSPORE/model_zoo/psenet/eval_ic15/. If your folder structure is different, you may need to change the corresponding paths in the evaluation script files.
```shell
bash ./scripts/run_eval_ascend.sh
# the first argument is the trained model file, the second is the downloaded inference dataset
python test.py --ckpt [CKPT_PATH] --TEST_ROOT_DIR [TEST_DATA_DIR]
# click [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization) to download the evaluation method
# click the "My Methods" button, then choose Offline evaluation -> Evaluation Scripts
# after downloading, put the data under the /path_to_data path
mkdir eval_ic15
ln -s /path_to_data/script_test_ch4_t1_e1-1577983151.zip eval_ic15/script_test_ch4_t1_e1-1577983151.zip
cd eval_ic15
unzip script_test_ch4_t1_e1-1577983151.zip
cd ..
bash ./scripts/run_eval_ascend.sh
```
#### Result
### Result
Calculated!{"precision": 0.8147966668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}
@ -317,7 +339,8 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
| Loss Function | LossCallBack |
| Outputs | probability |
| Loss | 0.35 |
| Speed | 1pc: 444 ms/step; 8pcs: 446 ms/step |
| Training Parameters | batch_size = 4 |
| Speed | 1pc: 444 ms/step (fps: 9.0); 8pcs: 446 ms/step (fps: 71) |
| Total time | 1pc: 75.48 h; 8pcs: 7.11 h |
| Parameters (M) | 27.36 |
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |

View File

@ -2,3 +2,5 @@ numpy
opencv-python
pillow
pyyaml
Polygon3
pyclipper

View File

@ -13,8 +13,7 @@
# limitations under the License.
# ============================================================================
mindspore_home = ${MINDSPORE_HOME}
CXXFLAGS = -I include -I ${mindspore_home}/model_zoo/official/cv/psenet -std=c++11 -O3
CXXFLAGS = -std=c++11 -O3
CXX_SOURCES = adaptor.cpp
opencv_home = ${OPENCV_HOME}
OPENCV = -I$(opencv_home)/include -L$(opencv_home)/lib64 -lopencv_superres -lopencv_ml -lopencv_objdetect \

View File

@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ETSNET/pse/adaptor.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
@ -26,6 +25,7 @@
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "./adaptor.h"
using std::vector;
using std::queue;

View File

@ -0,0 +1,326 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""transform pytorch checkpoint to mindspore checkpoint"""
import os
import argparse
import torch
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint
parser = argparse.ArgumentParser(description="transform pytorch checkpoint to mindspore checkpoint")
parser.add_argument("--torch_file", type=str, required=True, help="input pytorch checkpoint filename")
parser.add_argument("--output_path", type=str, required=True, help="output mindspore checkpoint path")
args = parser.parse_args()
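# Mapping from MindSpore parameter names (keys, as used by the PSENet backbone)
# to the corresponding PyTorch ResNet-50 state_dict names (values).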
weights_dict = {
"feature_extractor.conv1.weight": "conv1.weight",
"feature_extractor.bn1.moving_mean": "bn1.running_mean",
"feature_extractor.bn1.moving_variance": "bn1.running_var",
"feature_extractor.bn1.gamma": "bn1.weight",
"feature_extractor.bn1.beta": "bn1.bias",
"feature_extractor.layer1.0.conv1.weight": "layer1.0.conv1.weight",
"feature_extractor.layer1.0.bn1.moving_mean": "layer1.0.bn1.running_mean",
"feature_extractor.layer1.0.bn1.moving_variance": "layer1.0.bn1.running_var",
"feature_extractor.layer1.0.bn1.gamma": "layer1.0.bn1.weight",
"feature_extractor.layer1.0.bn1.beta": "layer1.0.bn1.bias",
"feature_extractor.layer1.0.conv2.weight": "layer1.0.conv2.weight",
"feature_extractor.layer1.0.bn2.moving_mean": "layer1.0.bn2.running_mean",
"feature_extractor.layer1.0.bn2.moving_variance": "layer1.0.bn2.running_var",
"feature_extractor.layer1.0.bn2.gamma": "layer1.0.bn2.weight",
"feature_extractor.layer1.0.bn2.beta": "layer1.0.bn2.bias",
"feature_extractor.layer1.0.conv3.weight": "layer1.0.conv3.weight",
"feature_extractor.layer1.0.bn3.moving_mean": "layer1.0.bn3.running_mean",
"feature_extractor.layer1.0.bn3.moving_variance": "layer1.0.bn3.running_var",
"feature_extractor.layer1.0.bn3.gamma": "layer1.0.bn3.weight",
"feature_extractor.layer1.0.bn3.beta": "layer1.0.bn3.bias",
"feature_extractor.layer1.0.conv_down_sample.weight": "layer1.0.downsample.0.weight",
"feature_extractor.layer1.0.bn_down_sample.moving_mean": "layer1.0.downsample.1.running_mean",
"feature_extractor.layer1.0.bn_down_sample.moving_variance": "layer1.0.downsample.1.running_var",
"feature_extractor.layer1.0.bn_down_sample.gamma": "layer1.0.downsample.1.weight",
"feature_extractor.layer1.0.bn_down_sample.beta": "layer1.0.downsample.1.bias",
"feature_extractor.layer1.1.conv1.weight": "layer1.1.conv1.weight",
"feature_extractor.layer1.1.bn1.moving_mean": "layer1.1.bn1.running_mean",
"feature_extractor.layer1.1.bn1.moving_variance": "layer1.1.bn1.running_var",
"feature_extractor.layer1.1.bn1.gamma": "layer1.1.bn1.weight",
"feature_extractor.layer1.1.bn1.beta": "layer1.1.bn1.bias",
"feature_extractor.layer1.1.conv2.weight": "layer1.1.conv2.weight",
"feature_extractor.layer1.1.bn2.moving_mean": "layer1.1.bn2.running_mean",
"feature_extractor.layer1.1.bn2.moving_variance": "layer1.1.bn2.running_var",
"feature_extractor.layer1.1.bn2.gamma": "layer1.1.bn2.weight",
"feature_extractor.layer1.1.bn2.beta": "layer1.1.bn2.bias",
"feature_extractor.layer1.1.conv3.weight": "layer1.1.conv3.weight",
"feature_extractor.layer1.1.bn3.moving_mean": "layer1.1.bn3.running_mean",
"feature_extractor.layer1.1.bn3.moving_variance": "layer1.1.bn3.running_var",
"feature_extractor.layer1.1.bn3.gamma": "layer1.1.bn3.weight",
"feature_extractor.layer1.1.bn3.beta": "layer1.1.bn3.bias",
"feature_extractor.layer1.2.conv1.weight": "layer1.2.conv1.weight",
"feature_extractor.layer1.2.bn1.moving_mean": "layer1.2.bn1.running_mean",
"feature_extractor.layer1.2.bn1.moving_variance": "layer1.2.bn1.running_var",
"feature_extractor.layer1.2.bn1.gamma": "layer1.2.bn1.weight",
"feature_extractor.layer1.2.bn1.beta": "layer1.2.bn1.bias",
"feature_extractor.layer1.2.conv2.weight": "layer1.2.conv2.weight",
"feature_extractor.layer1.2.bn2.moving_mean": "layer1.2.bn2.running_mean",
"feature_extractor.layer1.2.bn2.moving_variance": "layer1.2.bn2.running_var",
"feature_extractor.layer1.2.bn2.gamma": "layer1.2.bn2.weight",
"feature_extractor.layer1.2.bn2.beta": "layer1.2.bn2.bias",
"feature_extractor.layer1.2.conv3.weight": "layer1.2.conv3.weight",
"feature_extractor.layer1.2.bn3.moving_mean": "layer1.2.bn3.running_mean",
"feature_extractor.layer1.2.bn3.moving_variance": "layer1.2.bn3.running_var",
"feature_extractor.layer1.2.bn3.gamma": "layer1.2.bn3.weight",
"feature_extractor.layer1.2.bn3.beta": "layer1.2.bn3.bias",
"feature_extractor.layer2.0.conv1.weight": "layer2.0.conv1.weight",
"feature_extractor.layer2.0.bn1.moving_mean": "layer2.0.bn1.running_mean",
"feature_extractor.layer2.0.bn1.moving_variance": "layer2.0.bn1.running_var",
"feature_extractor.layer2.0.bn1.gamma": "layer2.0.bn1.weight",
"feature_extractor.layer2.0.bn1.beta": "layer2.0.bn1.bias",
"feature_extractor.layer2.0.conv2.weight": "layer2.0.conv2.weight",
"feature_extractor.layer2.0.bn2.moving_mean": "layer2.0.bn2.running_mean",
"feature_extractor.layer2.0.bn2.moving_variance": "layer2.0.bn2.running_var",
"feature_extractor.layer2.0.bn2.gamma": "layer2.0.bn2.weight",
"feature_extractor.layer2.0.bn2.beta": "layer2.0.bn2.bias",
"feature_extractor.layer2.0.conv3.weight": "layer2.0.conv3.weight",
"feature_extractor.layer2.0.bn3.moving_mean": "layer2.0.bn3.running_mean",
"feature_extractor.layer2.0.bn3.moving_variance": "layer2.0.bn3.running_var",
"feature_extractor.layer2.0.bn3.gamma": "layer2.0.bn3.weight",
"feature_extractor.layer2.0.bn3.beta": "layer2.0.bn3.bias",
"feature_extractor.layer2.0.conv_down_sample.weight": "layer2.0.downsample.0.weight",
"feature_extractor.layer2.0.bn_down_sample.moving_mean": "layer2.0.downsample.1.running_mean",
"feature_extractor.layer2.0.bn_down_sample.moving_variance": "layer2.0.downsample.1.running_var",
"feature_extractor.layer2.0.bn_down_sample.gamma": "layer2.0.downsample.1.weight",
"feature_extractor.layer2.0.bn_down_sample.beta": "layer2.0.downsample.1.bias",
"feature_extractor.layer2.1.conv1.weight": "layer2.1.conv1.weight",
"feature_extractor.layer2.1.bn1.moving_mean": "layer2.1.bn1.running_mean",
"feature_extractor.layer2.1.bn1.moving_variance": "layer2.1.bn1.running_var",
"feature_extractor.layer2.1.bn1.gamma": "layer2.1.bn1.weight",
"feature_extractor.layer2.1.bn1.beta": "layer2.1.bn1.bias",
"feature_extractor.layer2.1.conv2.weight": "layer2.1.conv2.weight",
"feature_extractor.layer2.1.bn2.moving_mean": "layer2.1.bn2.running_mean",
"feature_extractor.layer2.1.bn2.moving_variance": "layer2.1.bn2.running_var",
"feature_extractor.layer2.1.bn2.gamma": "layer2.1.bn2.weight",
"feature_extractor.layer2.1.bn2.beta": "layer2.1.bn2.bias",
"feature_extractor.layer2.1.conv3.weight": "layer2.1.conv3.weight",
"feature_extractor.layer2.1.bn3.moving_mean": "layer2.1.bn3.running_mean",
"feature_extractor.layer2.1.bn3.moving_variance": "layer2.1.bn3.running_var",
"feature_extractor.layer2.1.bn3.gamma": "layer2.1.bn3.weight",
"feature_extractor.layer2.1.bn3.beta": "layer2.1.bn3.bias",
"feature_extractor.layer2.2.conv1.weight": "layer2.2.conv1.weight",
"feature_extractor.layer2.2.bn1.moving_mean": "layer2.2.bn1.running_mean",
"feature_extractor.layer2.2.bn1.moving_variance": "layer2.2.bn1.running_var",
"feature_extractor.layer2.2.bn1.gamma": "layer2.2.bn1.weight",
"feature_extractor.layer2.2.bn1.beta": "layer2.2.bn1.bias",
"feature_extractor.layer2.2.conv2.weight": "layer2.2.conv2.weight",
"feature_extractor.layer2.2.bn2.moving_mean": "layer2.2.bn2.running_mean",
"feature_extractor.layer2.2.bn2.moving_variance": "layer2.2.bn2.running_var",
"feature_extractor.layer2.2.bn2.gamma": "layer2.2.bn2.weight",
"feature_extractor.layer2.2.bn2.beta": "layer2.2.bn2.bias",
"feature_extractor.layer2.2.conv3.weight": "layer2.2.conv3.weight",
"feature_extractor.layer2.2.bn3.moving_mean": "layer2.2.bn3.running_mean",
"feature_extractor.layer2.2.bn3.moving_variance": "layer2.2.bn3.running_var",
"feature_extractor.layer2.2.bn3.gamma": "layer2.2.bn3.weight",
"feature_extractor.layer2.2.bn3.beta": "layer2.2.bn3.bias",
"feature_extractor.layer2.3.conv1.weight": "layer2.3.conv1.weight",
"feature_extractor.layer2.3.bn1.moving_mean": "layer2.3.bn1.running_mean",
"feature_extractor.layer2.3.bn1.moving_variance": "layer2.3.bn1.running_var",
"feature_extractor.layer2.3.bn1.gamma": "layer2.3.bn1.weight",
"feature_extractor.layer2.3.bn1.beta": "layer2.3.bn1.bias",
"feature_extractor.layer2.3.conv2.weight": "layer2.3.conv2.weight",
"feature_extractor.layer2.3.bn2.moving_mean": "layer2.3.bn2.running_mean",
"feature_extractor.layer2.3.bn2.moving_variance": "layer2.3.bn2.running_var",
"feature_extractor.layer2.3.bn2.gamma": "layer2.3.bn2.weight",
"feature_extractor.layer2.3.bn2.beta": "layer2.3.bn2.bias",
"feature_extractor.layer2.3.conv3.weight": "layer2.3.conv3.weight",
"feature_extractor.layer2.3.bn3.moving_mean": "layer2.3.bn3.running_mean",
"feature_extractor.layer2.3.bn3.moving_variance": "layer2.3.bn3.running_var",
"feature_extractor.layer2.3.bn3.gamma": "layer2.3.bn3.weight",
"feature_extractor.layer2.3.bn3.beta": "layer2.3.bn3.bias",
"feature_extractor.layer3.0.conv1.weight": "layer3.0.conv1.weight",
"feature_extractor.layer3.0.bn1.moving_mean": "layer3.0.bn1.running_mean",
"feature_extractor.layer3.0.bn1.moving_variance": "layer3.0.bn1.running_var",
"feature_extractor.layer3.0.bn1.gamma": "layer3.0.bn1.weight",
"feature_extractor.layer3.0.bn1.beta": "layer3.0.bn1.bias",
"feature_extractor.layer3.0.conv2.weight": "layer3.0.conv2.weight",
"feature_extractor.layer3.0.bn2.moving_mean": "layer3.0.bn2.running_mean",
"feature_extractor.layer3.0.bn2.moving_variance": "layer3.0.bn2.running_var",
"feature_extractor.layer3.0.bn2.gamma": "layer3.0.bn2.weight",
"feature_extractor.layer3.0.bn2.beta": "layer3.0.bn2.bias",
"feature_extractor.layer3.0.conv3.weight": "layer3.0.conv3.weight",
"feature_extractor.layer3.0.bn3.moving_mean": "layer3.0.bn3.running_mean",
"feature_extractor.layer3.0.bn3.moving_variance": "layer3.0.bn3.running_var",
"feature_extractor.layer3.0.bn3.gamma": "layer3.0.bn3.weight",
"feature_extractor.layer3.0.bn3.beta": "layer3.0.bn3.bias",
"feature_extractor.layer3.0.conv_down_sample.weight": "layer3.0.downsample.0.weight",
"feature_extractor.layer3.0.bn_down_sample.moving_mean": "layer3.0.downsample.1.running_mean",
"feature_extractor.layer3.0.bn_down_sample.moving_variance": "layer3.0.downsample.1.running_var",
"feature_extractor.layer3.0.bn_down_sample.gamma": "layer3.0.downsample.1.weight",
"feature_extractor.layer3.0.bn_down_sample.beta": "layer3.0.downsample.1.bias",
"feature_extractor.layer3.1.conv1.weight": "layer3.1.conv1.weight",
"feature_extractor.layer3.1.bn1.moving_mean": "layer3.1.bn1.running_mean",
"feature_extractor.layer3.1.bn1.moving_variance": "layer3.1.bn1.running_var",
"feature_extractor.layer3.1.bn1.gamma": "layer3.1.bn1.weight",
"feature_extractor.layer3.1.bn1.beta": "layer3.1.bn1.bias",
"feature_extractor.layer3.1.conv2.weight": "layer3.1.conv2.weight",
"feature_extractor.layer3.1.bn2.moving_mean": "layer3.1.bn2.running_mean",
"feature_extractor.layer3.1.bn2.moving_variance": "layer3.1.bn2.running_var",
"feature_extractor.layer3.1.bn2.gamma": "layer3.1.bn2.weight",
"feature_extractor.layer3.1.bn2.beta": "layer3.1.bn2.bias",
"feature_extractor.layer3.1.conv3.weight": "layer3.1.conv3.weight",
"feature_extractor.layer3.1.bn3.moving_mean": "layer3.1.bn3.running_mean",
"feature_extractor.layer3.1.bn3.moving_variance": "layer3.1.bn3.running_var",
"feature_extractor.layer3.1.bn3.gamma": "layer3.1.bn3.weight",
"feature_extractor.layer3.1.bn3.beta": "layer3.1.bn3.bias",
"feature_extractor.layer3.2.conv1.weight": "layer3.2.conv1.weight",
"feature_extractor.layer3.2.bn1.moving_mean": "layer3.2.bn1.running_mean",
"feature_extractor.layer3.2.bn1.moving_variance": "layer3.2.bn1.running_var",
"feature_extractor.layer3.2.bn1.gamma": "layer3.2.bn1.weight",
"feature_extractor.layer3.2.bn1.beta": "layer3.2.bn1.bias",
"feature_extractor.layer3.2.conv2.weight": "layer3.2.conv2.weight",
"feature_extractor.layer3.2.bn2.moving_mean": "layer3.2.bn2.running_mean",
"feature_extractor.layer3.2.bn2.moving_variance": "layer3.2.bn2.running_var",
"feature_extractor.layer3.2.bn2.gamma": "layer3.2.bn2.weight",
"feature_extractor.layer3.2.bn2.beta": "layer3.2.bn2.bias",
"feature_extractor.layer3.2.conv3.weight": "layer3.2.conv3.weight",
"feature_extractor.layer3.2.bn3.moving_mean": "layer3.2.bn3.running_mean",
"feature_extractor.layer3.2.bn3.moving_variance": "layer3.2.bn3.running_var",
"feature_extractor.layer3.2.bn3.gamma": "layer3.2.bn3.weight",
"feature_extractor.layer3.2.bn3.beta": "layer3.2.bn3.bias",
"feature_extractor.layer3.3.conv1.weight": "layer3.3.conv1.weight",
"feature_extractor.layer3.3.bn1.moving_mean": "layer3.3.bn1.running_mean",
"feature_extractor.layer3.3.bn1.moving_variance": "layer3.3.bn1.running_var",
"feature_extractor.layer3.3.bn1.gamma": "layer3.3.bn1.weight",
"feature_extractor.layer3.3.bn1.beta": "layer3.3.bn1.bias",
"feature_extractor.layer3.3.conv2.weight": "layer3.3.conv2.weight",
"feature_extractor.layer3.3.bn2.moving_mean": "layer3.3.bn2.running_mean",
"feature_extractor.layer3.3.bn2.moving_variance": "layer3.3.bn2.running_var",
"feature_extractor.layer3.3.bn2.gamma": "layer3.3.bn2.weight",
"feature_extractor.layer3.3.bn2.beta": "layer3.3.bn2.bias",
"feature_extractor.layer3.3.conv3.weight": "layer3.3.conv3.weight",
"feature_extractor.layer3.3.bn3.moving_mean": "layer3.3.bn3.running_mean",
"feature_extractor.layer3.3.bn3.moving_variance": "layer3.3.bn3.running_var",
"feature_extractor.layer3.3.bn3.gamma": "layer3.3.bn3.weight",
"feature_extractor.layer3.3.bn3.beta": "layer3.3.bn3.bias",
"feature_extractor.layer3.4.conv1.weight": "layer3.4.conv1.weight",
"feature_extractor.layer3.4.bn1.moving_mean": "layer3.4.bn1.running_mean",
"feature_extractor.layer3.4.bn1.moving_variance": "layer3.4.bn1.running_var",
"feature_extractor.layer3.4.bn1.gamma": "layer3.4.bn1.weight",
"feature_extractor.layer3.4.bn1.beta": "layer3.4.bn1.bias",
"feature_extractor.layer3.4.conv2.weight": "layer3.4.conv2.weight",
"feature_extractor.layer3.4.bn2.moving_mean": "layer3.4.bn2.running_mean",
"feature_extractor.layer3.4.bn2.moving_variance": "layer3.4.bn2.running_var",
"feature_extractor.layer3.4.bn2.gamma": "layer3.4.bn2.weight",
"feature_extractor.layer3.4.bn2.beta": "layer3.4.bn2.bias",
"feature_extractor.layer3.4.conv3.weight": "layer3.4.conv3.weight",
"feature_extractor.layer3.4.bn3.moving_mean": "layer3.4.bn3.running_mean",
"feature_extractor.layer3.4.bn3.moving_variance": "layer3.4.bn3.running_var",
"feature_extractor.layer3.4.bn3.gamma": "layer3.4.bn3.weight",
"feature_extractor.layer3.4.bn3.beta": "layer3.4.bn3.bias",
"feature_extractor.layer3.5.conv1.weight": "layer3.5.conv1.weight",
"feature_extractor.layer3.5.bn1.moving_mean": "layer3.5.bn1.running_mean",
"feature_extractor.layer3.5.bn1.moving_variance": "layer3.5.bn1.running_var",
"feature_extractor.layer3.5.bn1.gamma": "layer3.5.bn1.weight",
"feature_extractor.layer3.5.bn1.beta": "layer3.5.bn1.bias",
"feature_extractor.layer3.5.conv2.weight": "layer3.5.conv2.weight",
"feature_extractor.layer3.5.bn2.moving_mean": "layer3.5.bn2.running_mean",
"feature_extractor.layer3.5.bn2.moving_variance": "layer3.5.bn2.running_var",
"feature_extractor.layer3.5.bn2.gamma": "layer3.5.bn2.weight",
"feature_extractor.layer3.5.bn2.beta": "layer3.5.bn2.bias",
"feature_extractor.layer3.5.conv3.weight": "layer3.5.conv3.weight",
"feature_extractor.layer3.5.bn3.moving_mean": "layer3.5.bn3.running_mean",
"feature_extractor.layer3.5.bn3.moving_variance": "layer3.5.bn3.running_var",
"feature_extractor.layer3.5.bn3.gamma": "layer3.5.bn3.weight",
"feature_extractor.layer3.5.bn3.beta": "layer3.5.bn3.bias",
"feature_extractor.layer4.0.conv1.weight": "layer4.0.conv1.weight",
"feature_extractor.layer4.0.bn1.moving_mean": "layer4.0.bn1.running_mean",
"feature_extractor.layer4.0.bn1.moving_variance": "layer4.0.bn1.running_var",
"feature_extractor.layer4.0.bn1.gamma": "layer4.0.bn1.weight",
"feature_extractor.layer4.0.bn1.beta": "layer4.0.bn1.bias",
"feature_extractor.layer4.0.conv2.weight": "layer4.0.conv2.weight",
"feature_extractor.layer4.0.bn2.moving_mean": "layer4.0.bn2.running_mean",
"feature_extractor.layer4.0.bn2.moving_variance": "layer4.0.bn2.running_var",
"feature_extractor.layer4.0.bn2.gamma": "layer4.0.bn2.weight",
"feature_extractor.layer4.0.bn2.beta": "layer4.0.bn2.bias",
"feature_extractor.layer4.0.conv3.weight": "layer4.0.conv3.weight",
"feature_extractor.layer4.0.bn3.moving_mean": "layer4.0.bn3.running_mean",
"feature_extractor.layer4.0.bn3.moving_variance": "layer4.0.bn3.running_var",
"feature_extractor.layer4.0.bn3.gamma": "layer4.0.bn3.weight",
"feature_extractor.layer4.0.bn3.beta": "layer4.0.bn3.bias",
"feature_extractor.layer4.0.conv_down_sample.weight": "layer4.0.downsample.0.weight",
"feature_extractor.layer4.0.bn_down_sample.moving_mean": "layer4.0.downsample.1.running_mean",
"feature_extractor.layer4.0.bn_down_sample.moving_variance": "layer4.0.downsample.1.running_var",
"feature_extractor.layer4.0.bn_down_sample.gamma": "layer4.0.downsample.1.weight",
"feature_extractor.layer4.0.bn_down_sample.beta": "layer4.0.downsample.1.bias",
"feature_extractor.layer4.1.conv1.weight": "layer4.1.conv1.weight",
"feature_extractor.layer4.1.bn1.moving_mean": "layer4.1.bn1.running_mean",
"feature_extractor.layer4.1.bn1.moving_variance": "layer4.1.bn1.running_var",
"feature_extractor.layer4.1.bn1.gamma": "layer4.1.bn1.weight",
"feature_extractor.layer4.1.bn1.beta": "layer4.1.bn1.bias",
"feature_extractor.layer4.1.conv2.weight": "layer4.1.conv2.weight",
"feature_extractor.layer4.1.bn2.moving_mean": "layer4.1.bn2.running_mean",
"feature_extractor.layer4.1.bn2.moving_variance": "layer4.1.bn2.running_var",
"feature_extractor.layer4.1.bn2.gamma": "layer4.1.bn2.weight",
"feature_extractor.layer4.1.bn2.beta": "layer4.1.bn2.bias",
"feature_extractor.layer4.1.conv3.weight": "layer4.1.conv3.weight",
"feature_extractor.layer4.1.bn3.moving_mean": "layer4.1.bn3.running_mean",
"feature_extractor.layer4.1.bn3.moving_variance": "layer4.1.bn3.running_var",
"feature_extractor.layer4.1.bn3.gamma": "layer4.1.bn3.weight",
"feature_extractor.layer4.1.bn3.beta": "layer4.1.bn3.bias",
"feature_extractor.layer4.2.conv1.weight": "layer4.2.conv1.weight",
"feature_extractor.layer4.2.bn1.moving_mean": "layer4.2.bn1.running_mean",
"feature_extractor.layer4.2.bn1.moving_variance": "layer4.2.bn1.running_var",
"feature_extractor.layer4.2.bn1.gamma": "layer4.2.bn1.weight",
"feature_extractor.layer4.2.bn1.beta": "layer4.2.bn1.bias",
"feature_extractor.layer4.2.conv2.weight": "layer4.2.conv2.weight",
"feature_extractor.layer4.2.bn2.moving_mean": "layer4.2.bn2.running_mean",
"feature_extractor.layer4.2.bn2.moving_variance": "layer4.2.bn2.running_var",
"feature_extractor.layer4.2.bn2.gamma": "layer4.2.bn2.weight",
"feature_extractor.layer4.2.bn2.beta": "layer4.2.bn2.bias",
"feature_extractor.layer4.2.conv3.weight": "layer4.2.conv3.weight",
"feature_extractor.layer4.2.bn3.moving_mean": "layer4.2.bn3.running_mean",
"feature_extractor.layer4.2.bn3.moving_variance": "layer4.2.bn3.running_var",
"feature_extractor.layer4.2.bn3.gamma": "layer4.2.bn3.weight",
"feature_extractor.layer4.2.bn3.beta": "layer4.2.bn3.bias",
}
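# Load the PyTorch state_dict, then build the list of {"name", "data"} entries that
# MindSpore's save_checkpoint expects, converting each mapped tensor via NumPy.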
torch_param_dict = torch.load(args.torch_file)
ms_params = []
for ms_name, torch_name in weights_dict.items():
if torch_name not in torch_param_dict.keys():
print("param {} not in pytorch checkpoint".format(torch_name))
continue
torch_value = torch_param_dict[torch_name]
np_value = torch_value.data.numpy()
each_param = dict()
each_param["name"] = ms_name
each_param["data"] = Tensor(np_value)
ms_params.append(each_param)
save_checkpoint(ms_params, os.path.join(args.output_path, "pretrained_model.ckpt"))

View File

@ -100,7 +100,7 @@ def train():
if config.pre_trained:
param_dict = load_checkpoint(config.pre_trained)
load_param_into_net(net, param_dict)
load_param_into_net(net, param_dict, strict_load=True)
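# strict_load=True loads parameters by exact name match instead of the suffix-matching fallback.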
print('Load Pretrained parameters done!')
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)