forked from mindspore-Ecosystem/mindspore
!11871 add convert checkpoint script for maskrcnn and faster-rcnn
From: @qujianwei Reviewed-by: @c_34,@linqingke Signed-off-by: @c_34
This commit is contained in:
commit
9030e547ef
|
@ -84,7 +84,7 @@ Dataset used: [COCO2017](<https://cocodataset.org/>)
|
|||
```
|
||||
|
||||
2. If your own dataset is used. **Select dataset to other when run script.**
|
||||
Organize the dataset infomation into a TXT file, each row in the file is as follows:
|
||||
Organize the dataset information into a TXT file, each row in the file is as follows:
|
||||
|
||||
```log
|
||||
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
|
||||
|
@ -97,10 +97,14 @@ Dataset used: [COCO2017](<https://cocodataset.org/>)
|
|||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
Note: 1.the first run will generate the mindeocrd file, which will take a long time.
|
||||
2.pretrained model is a resnet50 checkpoint that trained over ImageNet2012.
|
||||
3.VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
|
||||
2.pretrained model is a resnet50 checkpoint that trained over ImageNet2012.you can train it with [resnet50](https://gitee.com/qujianwei/mindspore/tree/master/model_zoo/official/cv/resnet) scripts in modelzoo, and use src/convert_checkpoint.py to get the pretrain model.
|
||||
3.BACKBONE_MODEL is a checkpoint file trained with [resnet50](https://gitee.com/qujianwei/mindspore/tree/master/model_zoo/official/cv/resnet) scripts in modelzoo.PRETRAINED_MODEL is a checkpoint file after convert.VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
|
||||
|
||||
```shell
|
||||
|
||||
# convert checkpoint
|
||||
python convert_checkpoint.py --ckpt_file=[BACKBONE_MODEL]
|
||||
|
||||
# standalone training
|
||||
sh run_standalone_train_ascend.sh [PRETRAINED_MODEL]
|
||||
|
||||
|
@ -287,7 +291,7 @@ Eval result will be stored in the example path, whose folder name is "eval". Und
|
|||
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT]
|
||||
```
|
||||
|
||||
`EXPORT_FORMAT` shoule be in ["AIR", "ONNX", "MINDIR"]
|
||||
`EXPORT_FORMAT` should be in ["AIR", "ONNX", "MINDIR"]
|
||||
|
||||
## Inference Process
|
||||
|
||||
|
|
|
@ -100,10 +100,14 @@ Faster R-CNN是一个两阶段目标检测网络,该网络采用RPN,可以
|
|||
注意:
|
||||
|
||||
1. 第一次运行生成MindRecord文件,耗时较长。
|
||||
2. 预训练模型是在ImageNet2012上训练的ResNet-50检查点。
|
||||
3. VALIDATION_JSON_FILE为标签文件。CHECKPOINT_PATH是训练后的检查点文件。
|
||||
2. 预训练模型是在ImageNet2012上训练的ResNet-50检查点。你可以使用ModelZoo中 [resnet50](https://gitee.com/qujianwei/mindspore/tree/master/model_zoo/official/cv/resnet) 脚本来训练, 然后使用src/convert_checkpoint.py把训练好的resnet50的权重文件转换为可加载的权重文件。
|
||||
3. BACKBONE_MODEL是通过modelzoo中的[resnet50](https://gitee.com/qujianwei/mindspore/tree/master/model_zoo/official/cv/resnet)脚本训练的。PRETRAINED_MODEL是经过转换后的权重文件。VALIDATION_JSON_FILE为标签文件。CHECKPOINT_PATH是训练后的检查点文件。
|
||||
|
||||
```shell
|
||||
|
||||
# 权重文件转换
|
||||
python convert_checkpoint.py --ckpt_file=[BACKBONE_MODEL]
|
||||
|
||||
# 单机训练
|
||||
sh run_standalone_train_ascend.sh [PRETRAINED_MODEL]
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""
|
||||
convert resnet50 pretrain model to faster_rcnn backbone pretrain model
|
||||
"""
|
||||
import argparse
|
||||
from mindspore.train.serialization import load_checkpoint, save_checkpoint
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
parser = argparse.ArgumentParser(description='load_ckpt')
|
||||
parser.add_argument('--ckpt_file', type=str, default='', help='ckpt file path')
|
||||
args_opt = parser.parse_args()
|
||||
def load_weights(model_path, use_fp16_weight):
|
||||
"""
|
||||
load resnet50 pretrain checkpoint file.
|
||||
|
||||
Args:
|
||||
model_path (str): resnet50 pretrain checkpoint file .
|
||||
use_fp16_weight(bool): whether save weight into float16.
|
||||
|
||||
Returns:
|
||||
parameter list(list): pretrain model weight list.
|
||||
"""
|
||||
ms_ckpt = load_checkpoint(model_path)
|
||||
weights = {}
|
||||
for msname in ms_ckpt:
|
||||
if msname.startswith("layer") or msname.startswith("conv1") or msname.startswith("bn"):
|
||||
param_name = "backbone." + msname
|
||||
else:
|
||||
param_name = msname
|
||||
if "down_sample_layer.0" in param_name:
|
||||
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
|
||||
if "down_sample-layer.1" in param_name:
|
||||
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
|
||||
weights[param_name] = ms_ckpt[msname].data.asnumpy()
|
||||
if use_fp16_weight:
|
||||
dtype = mstype.float16
|
||||
else:
|
||||
dtype = mstype.float32
|
||||
parameter_dict = {}
|
||||
for name in weights:
|
||||
parameter_dict[name] = Parameter(Tensor(weights[name], dtype), name=name)
|
||||
param_list = []
|
||||
for key, value in parameter_dict.items():
|
||||
param_list.append({"name": key, "data": value})
|
||||
return param_list
|
||||
|
||||
if __name__ == "__main__":
|
||||
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=True)
|
||||
save_checkpoint(parameter_list, "resnet50_backbone.ckpt")
|
|
@ -109,7 +109,7 @@ pip install mmcv=0.2.14
|
|||
Note:
|
||||
1. To speed up data preprocessing, MindSpore provide a data format named MindRecord, hence the first step is to generate MindRecord files based on COCO2017 dataset before training. The process of converting raw COCO2017 dataset to MindRecord format may take about 4 hours.
|
||||
2. For distributed training, a [hccl configuration file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) with JSON format needs to be created in advance.
|
||||
3. PRETRAINED_CKPT is a resnet50 checkpoint that trained over ImageNet2012.
|
||||
3. PRETRAINED_CKPT is a resnet50 checkpoint that trained over ImageNet2012.you can train it with [resnet50](https://gitee.com/qujianwei/mindspore/tree/master/model_zoo/official/cv/resnet) scripts in modelzoo, and use src/convert_checkpoint.py to get the pretrain checkpoint file.
|
||||
4. For large models like MaskRCNN, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
|
||||
|
||||
4. Execute eval script.
|
||||
|
@ -205,6 +205,7 @@ bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
|
|||
└─rpn.py # reagion proposal network
|
||||
├─aipp.cfg #aipp config file
|
||||
├─config.py # network configuration
|
||||
├─convert_checkpoint.py # convert resnet50 backbone checkpoint
|
||||
├─dataset.py # dataset utils
|
||||
├─lr_schedule.py # leanring rate geneatore
|
||||
├─network_define.py # network define for maskrcnn
|
||||
|
@ -272,7 +273,7 @@ Usage: bash run_standalone_train.sh [PRETRAINED_MODEL]
|
|||
"neg_iou_thr": 0.3, # negative sample threshold after IOU
|
||||
"pos_iou_thr": 0.7, # positive sample threshold after IOU
|
||||
"min_pos_iou": 0.3, # minimal positive sample threshold after IOU
|
||||
"num_bboxes": 245520, # total bbox numner
|
||||
"num_bboxes": 245520, # total bbox number
|
||||
"num_gts": 128, # total ground truth number
|
||||
"num_expected_neg": 256, # negative sample number
|
||||
"num_expected_pos": 128, # positive sample number
|
||||
|
@ -284,7 +285,7 @@ Usage: bash run_standalone_train.sh [PRETRAINED_MODEL]
|
|||
# roi_alignj
|
||||
"roi_layer": dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2), # ROIAlign parameters
|
||||
"roi_align_out_channels": 256, # ROIAlign out channels size
|
||||
"roi_align_featmap_strides": [4, 8, 16, 32], # stride size for differnt level of ROIAling feature map
|
||||
"roi_align_featmap_strides": [4, 8, 16, 32], # stride size for different level of ROIAling feature map
|
||||
"roi_align_finest_scale": 56, # finest scale ofr ROIAlign
|
||||
"roi_sample_num": 640, # sample number in ROIAling layer
|
||||
|
||||
|
@ -499,7 +500,7 @@ Accumulating evaluation results...
|
|||
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT]
|
||||
```
|
||||
|
||||
`EXPORT_FORMAT` shoule be in ["AIR", "ONNX", "MINDIR"]
|
||||
`EXPORT_FORMAT` should be in ["AIR", "ONNX", "MINDIR"]
|
||||
|
||||
## Inference Process
|
||||
|
||||
|
|
|
@ -111,7 +111,7 @@ pip install mmcv=0.2.14
|
|||
注:
|
||||
1. 为加快数据预处理速度,MindSpore提供了MindRecord数据格式。因此,训练前首先需要生成基于COCO2017数据集的MindRecord文件。COCO2017原始数据集转换为MindRecord格式大概需要4小时。
|
||||
2. 进行分布式训练前,需要提前创建JSON格式的[hccl配置文件](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)。
|
||||
3. PRETRAINED_CKPT是一个ResNet50检查点,通过ImageNet2012训练。
|
||||
3. PRETRAINED_CKPT是一个ResNet50检查点,通过ImageNet2012训练。你可以使用ModelZoo中 [resnet50](https://gitee.com/qujianwei/mindspore/tree/master/model_zoo/official/cv/resnet) 脚本来训练, 然后使用src/convert_checkpoint.py把训练好的resnet50的权重文件转换为可加载的权重文件。
|
||||
|
||||
4. 执行评估脚本。
|
||||
训练结束后,按照如下步骤启动评估:
|
||||
|
@ -199,6 +199,7 @@ bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
|
|||
└─rpn.py # 区域候选网络
|
||||
├─aipp.cfg #aipp 配置文件
|
||||
├─config.py # 网络配置
|
||||
├─convert_checkpoint.py # 转换预训练checkpoint文件
|
||||
├─dataset.py # 数据集工具
|
||||
├─lr_schedule.py # 学习率生成器
|
||||
├─network_define.py # MaskRCNN的网络定义
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""
|
||||
convert resnet50 pretrain model to faster_rcnn backbone pretrain model
|
||||
"""
|
||||
import argparse
|
||||
from mindspore.train.serialization import load_checkpoint, save_checkpoint
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
parser = argparse.ArgumentParser(description='load_ckpt')
|
||||
parser.add_argument('--ckpt_file', type=str, default='', help='ckpt file path')
|
||||
args_opt = parser.parse_args()
|
||||
def load_weights(model_path, use_fp16_weight):
|
||||
"""
|
||||
load resnet50 pretrain checkpoint file.
|
||||
|
||||
Args:
|
||||
model_path (str): resnet50 pretrain checkpoint file .
|
||||
use_fp16_weight(bool): whether save weight into float16.
|
||||
|
||||
Returns:
|
||||
parameter list(list): pretrain model weight list.
|
||||
"""
|
||||
ms_ckpt = load_checkpoint(model_path)
|
||||
weights = {}
|
||||
for msname in ms_ckpt:
|
||||
if msname.startswith("layer") or msname.startswith("conv1") or msname.startswith("bn"):
|
||||
param_name = "backbone." + msname
|
||||
else:
|
||||
param_name = msname
|
||||
if "down_sample_layer.0" in param_name:
|
||||
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
|
||||
if "down_sample-layer.1" in param_name:
|
||||
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
|
||||
weights[param_name] = ms_ckpt[msname].data.asnumpy()
|
||||
if use_fp16_weight:
|
||||
dtype = mstype.float16
|
||||
else:
|
||||
dtype = mstype.float32
|
||||
parameter_dict = {}
|
||||
for name in weights:
|
||||
parameter_dict[name] = Parameter(Tensor(weights[name], dtype), name=name)
|
||||
param_list = []
|
||||
for key, value in parameter_dict.items():
|
||||
param_list.append({"name": key, "data": value})
|
||||
return param_list
|
||||
|
||||
if __name__ == "__main__":
|
||||
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=True)
|
||||
save_checkpoint(parameter_list, "resnet50_backbone.ckpt")
|
Loading…
Reference in New Issue