This commit is contained in:
huchunmei 2021-05-25 15:28:51 +08:00
parent e513786709
commit 18c3cbac05
22 changed files with 798 additions and 292 deletions

View File

@ -155,17 +155,17 @@ bash scripts/docker_start.sh maskrcnn:20.1.0 [DATA_DIR] [MODEL_DIR]
```shell
# standalone training
bash run_standalone_train.sh [PRETRAINED_CKPT]
bash run_standalone_train.sh [PRETRAINED_CKPT] [DATA_PATH]
# distributed training
bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT]
bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT] [DATA_PATH]
```
4. Eval
```shell
# Evaluation
bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH] [DATA_PATH]
```
5. Inference.
@ -203,12 +203,17 @@ bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
├─resnet50.py # backbone network
├─roi_align.py # roi align network
└─rpn.py # reagion proposal network
├─config.py # network configuration
├─convert_checkpoint.py # convert resnet50 backbone checkpoint
├─dataset.py # dataset utils
├─lr_schedule.py # leanring rate geneatore
├─network_define.py # network define for maskrcnn
└─util.py # routine operation
├─util.py # routine operation
└─model_utils
├─config.py # Processing configuration parameters
├─device_adapter.py # Get cloud ID
├─local_adapter.py # Get local ID
└─moxing_adapter.py # Parameter processing
├─default_config.yaml # Training parameter profile
├─mindspore_hub_conf.py # mindspore hub interface
├─export.py #script to export AIR,MINDIR,ONNX model
├─eval.py # evaluation scripts

View File

@ -197,12 +197,17 @@ bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
├─resnet50.py # 骨干网
├─roi_align.py # 兴趣点对齐网络
└─rpn.py # 区域候选网络
├─config.py # 网络配置
├─convert_checkpoint.py # 转换预训练checkpoint文件
├─dataset.py # 数据集工具
├─lr_schedule.py # 学习率生成器
├─network_define.py # MaskRCNN的网络定义
└─util.py # 例行操作
├─util.py # 例行操作
└─model_utils
├─config.py # 训练配置
├─device_adapter.py # 获取云上id
├─local_adapter.py # 获取本地id
└─moxing_adapter.py # 参数处理
├─default_config.yaml # 训练参数配置文件
├─mindspore_hub_conf.py # MindSpore hub接口
├─export.py #导出 AIR,MINDIR,ONNX模型的脚本
├─eval.py # 评估脚本

View File

@ -28,9 +28,9 @@
#include "include/api/model.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision.h"
#include "include/dataset/transforms.h"
#include "include/dataset/execute.h"
#include "include/minddata/dataset/include/vision.h"
#include "include/minddata/dataset/include/transforms.h"
#include "include/minddata/dataset/include/execute.h"
#include "../inc/utils.h"
using mindspore::Context;

View File

@ -0,0 +1,217 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/mask_rcnn-12_7393.ckpt'
device_target: Ascend
enable_profiling: False
pre_trained: "/cache/data"
coco_root: "/cache/data"
ckpt_path: './ckpt_maskrcnn/mask_rcnn-12_7393.ckpt'
ckpt_file: '/cache/data/cocodataset/ckpt_maskrcnn/mask_rcnn-12_7393.ckpt'
ann_file: "./annotations/instances_val2017.json"
# ==============================================================================
modelarts_dataset_unzip_name: 'cocodataset'
need_modelarts_dataset_unzip: True
img_path: '' # "image file path."
result_path: '' # "result file path."
# Training options
img_width: 1280
img_height: 768
keep_ratio: True
flip_ratio: 0.5
expand_ratio: 1.0
max_instance_count: 128
mask_shape: (28, 28)
# anchor
feature_shapes: [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)]
anchor_scales: [8]
anchor_ratios: [0.5, 1.0, 2.0]
anchor_strides: [4, 8, 16, 32, 64]
num_anchors: 3
# resnet
resnet_block: [3, 4, 6, 3]
resnet_in_channels: [64, 256, 512, 1024]
resnet_out_channels: [256, 512, 1024, 2048]
# fpn
fpn_in_channels: [256, 512, 1024, 2048]
fpn_out_channels: 256
fpn_num_outs: 5
# rpn
rpn_in_channels: 256
rpn_feat_channels: 256
rpn_loss_cls_weight: 1.0
rpn_loss_reg_weight: 1.0
rpn_cls_out_channels: 1
rpn_target_means: [0., 0., 0., 0.]
rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
# bbox_assign_sampler
neg_iou_thr: 0.3
pos_iou_thr: 0.7
min_pos_iou: 0.3
num_bboxes: 245520
num_gts: 128
num_expected_neg: 256
num_expected_pos: 128
# proposal
activate_num_classes: 2
use_sigmoid_cls: True
# roi_align
roi_layer: dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2)
roi_align_out_channels: 256
roi_align_featmap_strides: [4, 8, 16, 32]
roi_align_finest_scale: 56
roi_sample_num: 640
# bbox_assign_sampler_stage2
neg_iou_thr_stage2: 0.5
pos_iou_thr_stage2: 0.5
min_pos_iou_stage2: 0.5
num_bboxes_stage2: 2000
num_expected_pos_stage2: 128
num_expected_neg_stage2: 512
num_expected_total_stage2: 512
# rcnn
rcnn_num_layers: 2
rcnn_in_channels: 256
rcnn_fc_out_channels: 1024
rcnn_mask_out_channels: 256
rcnn_loss_cls_weight: 1
rcnn_loss_reg_weight: 1
rcnn_loss_mask_fb_weight: 1
rcnn_target_means: [0., 0., 0., 0.]
rcnn_target_stds: [0.1, 0.1, 0.2, 0.2]
# train proposal
rpn_proposal_nms_across_levels: False
rpn_proposal_nms_pre: 2000
rpn_proposal_nms_post: 2000
rpn_proposal_max_num: 2000
rpn_proposal_nms_thr: 0.7
rpn_proposal_min_bbox_size: 0
# test proposal
rpn_nms_across_levels: False
rpn_nms_pre: 1000
rpn_nms_post: 1000
rpn_max_num: 1000
rpn_nms_thr: 0.7
rpn_min_bbox_min_size: 0
test_score_thr: 0.05
test_iou_thr: 0.5
test_max_per_img: 100
test_batch_size: 2
rpn_head_use_sigmoid: True
rpn_head_weight: 1.0
mask_thr_binary: 0.5
# LR
base_lr: 0.02
base_step: 58633
total_epoch: 13
warmup_step: 500
warmup_ratio: 1/3.0
sgd_momentum: 0.9
# train
batch_size: 2
loss_scale: 1
momentum: 0.91
weight_decay: 0.0001 # 1e-4
pretrain_epoch_size: 0
epoch_size: 12
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 12
save_checkpoint_path: "./"
mindrecord_dir: "./MindRecord_COCO" # "/home/mask_rcnn/MindRecord_COCO2017_Train"
train_data_type: "train2017"
val_data_type: "val2017"
instance_set: "annotations/instances_{}.json"
coco_classes: ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')
num_classes: 81
only_create_dataset: False
run_distribute: False
do_train: True
do_eval: False
dataset: "coco"
device_id: 0
device_num: 1
rank_id: 0
# batch_size_export: 1
file_name: "maskrcnn"
file_format: "AIR"
# other
learning_rate: 0.002
buffer_size: 1000
save_checkpoint_steps: 1562
sink_size: -1
dataset_sink_mode: True
lr: 0.01
# Model Description
model_name: maskrcnn
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
ann_file: 'Ann file, default is val.json.'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
run_distribute: 'Run distribute, default is false.'
do_train: 'Do train or not, default is true.'
do_eval: 'Do eval or not, default is false.'
dataset: 'Dataset, default is coco.'
pre_trained: 'Pretrain file path.'
device_id: 'Device id, default is 0.'
device_num: 'Use device nums, default is 1.'
rank_id: 'Rank id, default is 0.'
file_format: 'file format'
img_path: "image file path."
result_path: "result file path."
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "ONNX", "MINDIR"]

View File

@ -15,29 +15,31 @@
"""Evaluation for MaskRcnn"""
import os
import argparse
import time
import re
import numpy as np
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
from pycocotools.coco import COCO
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
set_seed(1)
parser = argparse.ArgumentParser(description="MaskRcnn evaluation")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
lss = [int(re.findall(r'[0-9]+', i)[0]) for i in config.feature_shapes]
config.feature_shapes = [(lss[2*i], lss[2*i+1]) for i in range(int(len(lss)/2))]
config.roi_layer = dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2)
config.warmup_ratio = 1/3.0
config.mask_shape = (28, 28)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
def maskrcnn_eval(dataset_path, ckpt_path, ann_file):
"""MaskRcnn evaluation."""
@ -106,14 +108,77 @@ def maskrcnn_eval(dataset_path, ckpt_path, ann_file):
result_files = results2json(dataset_coco, outputs, "./results.pkl")
coco_eval(result_files, eval_types, dataset_coco, single_result=False)
if __name__ == '__main__':
def modelarts_process():
""" modelarts process """
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
int(int(time.time() - s_time) % 60)))
print("Extract Done")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
print("#" * 200, os.listdir(save_dir_1))
print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.coco_root = config.dataset_path
config.checkpoint_path = os.path.join(config.dataset_path, config.ckpt_path)
config.ann_file = os.path.join(config.dataset_path, config.ann_file)
config.mindrecord_dir = os.path.join(config.dataset_path, config.mindrecord_dir)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
@moxing_wrapper(pre_process=modelarts_process)
def eval_():
prefix = "MaskRcnn_eval.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if config.dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
@ -129,5 +194,9 @@ if __name__ == '__main__':
print("IMAGE_DIR or ANNO_PATH not exits.")
print("Start Eval!")
maskrcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file)
print("ckpt_path=", args_opt.checkpoint_path)
maskrcnn_eval(mindrecord_file, config.checkpoint_path, config.ann_file)
print("ckpt_path=", config.checkpoint_path)
if __name__ == '__main__':
eval_()

View File

@ -13,31 +13,28 @@
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import re
import numpy as np
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.maskrcnn.mask_rcnn_r50 import MaskRcnn_Infer
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.maskrcnn.mask_rcnn_r50 import MaskRcnn_Infer
from src.config import config
lss = [int(re.findall(r'[0-9]+', i)[0]) for i in config.feature_shapes]
config.feature_shapes = [(lss[2*i], lss[2*i+1]) for i in range(int(len(lss)/2))]
config.roi_layer = dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2)
config.warmup_ratio = 1/3.0
config.mask_shape = (28, 28)
parser = argparse.ArgumentParser(description='maskrcnn export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="maskrcnn", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
if __name__ == '__main__':
net = MaskRcnn_Infer(config=config)
param_dict = load_checkpoint(args.ckpt_file)
param_dict = load_checkpoint(config.ckpt_file)
param_dict_new = {}
for key, value in param_dict.items():
@ -48,8 +45,8 @@ if __name__ == '__main__':
bs = config.test_batch_size
img = Tensor(np.zeros([args.batch_size, 3, config.img_height, config.img_width], np.float16))
img_metas = Tensor(np.zeros([args.batch_size, 4], np.float16))
img = Tensor(np.zeros([config.batch_size, 3, config.img_height, config.img_width], np.float16))
img_metas = Tensor(np.zeros([config.batch_size, 4], np.float16))
input_data = [img, img_metas]
export(net, *input_data, file_name=args.file_name, file_format=args.file_format)
export(net, *input_data, file_name=config.file_name, file_format=config.file_format)

View File

@ -14,7 +14,7 @@
# ============================================================================
"""hub config."""
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config
from src.model_utils.config import config
def create_network(name, *args, **kwargs):
if name == "maskrcnn":

View File

@ -14,23 +14,16 @@
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
from src.config import config
from src.model_utils.config import config
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
dst_width = 1280
dst_height = 768
parser = argparse.ArgumentParser(description="maskrcnn inference")
parser.add_argument("--ann_file", type=str, required=True, help="ann file.")
parser.add_argument("--img_path", type=str, required=True, help="image file path.")
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
args = parser.parse_args()
def get_img_size(file_name):
img = Image.open(file_name)
return img.size
@ -96,4 +89,4 @@ def get_eval_result(ann_file, img_path, result_path):
coco_eval(result_files, eval_types, dataset_coco, single_result=False)
if __name__ == '__main__':
get_eval_result(args.ann_file, args.img_path, args.result_path)
get_eval_result(config.ann_file, config.img_path, config.result_path)

View File

@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
if [ $# != 3 ]
then
echo "Usage: bash run_train.sh [RANK_TABLE_FILE] [PRETRAINED_PATH]"
echo "Usage: bash run_train.sh [RANK_TABLE_FILE] [PRETRAINED_PATH] [DATA_PATH]"
exit 1
fi
@ -29,9 +29,11 @@ get_real_path(){
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$(get_real_path $3)
echo $PATH1
echo $PATH2
echo $PATH3
if [ ! -f $PATH1 ]
then
@ -68,12 +70,13 @@ do
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp ../*.yaml ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python train.py --do_train=True --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM \
--pre_trained=$PATH2 &> log &
--pre_trained=$PATH2 --data_path=$PATH3 &> log &
cd ..
done

View File

@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
if [ $# != 3 ]
then
echo "Usage: bash run_eval.sh [ANN_FILE] [CHECKPOINT_PATH]"
echo "Usage: bash run_eval.sh [ANN_FILE] [CHECKPOINT_PATH] [DATA_PATH]"
exit 1
fi
@ -29,8 +29,10 @@ get_real_path(){
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$(get_real_path $3)
echo $PATH1
echo $PATH2
echo $PATH3
if [ ! -f $PATH1 ]
then
@ -56,10 +58,12 @@ then
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
env > env.log_eval
echo "start eval for device $DEVICE_ID"
python eval.py --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 &> log &
python ./eval.py --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 \
--data_path=$PATH3 &> log_eval.txt &
cd ..

View File

@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
if [ $# != 2 ]
then
echo "Usage: bash run_standalone_train.sh [PRETRAINED_PATH]"
echo "Usage: bash run_standalone_train.sh [PRETRAINED_PATH] [DATA_PATH]"
exit 1
fi
@ -28,7 +28,9 @@ get_real_path(){
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
@ -48,10 +50,11 @@ then
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --device_id=$DEVICE_ID --pre_trained=$PATH1 &> log &
python ./train.py --do_train=True --device_id=$DEVICE_ID --pre_trained=$PATH1 --data_path=$PATH2 &> log.txt &
cd ..

View File

@ -1,161 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" :===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"img_width": 1280,
"img_height": 768,
"keep_ratio": True,
"flip_ratio": 0.5,
"expand_ratio": 1.0,
"max_instance_count": 128,
"mask_shape": (28, 28),
# anchor
"feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)],
"anchor_scales": [8],
"anchor_ratios": [0.5, 1.0, 2.0],
"anchor_strides": [4, 8, 16, 32, 64],
"num_anchors": 3,
# resnet
"resnet_block": [3, 4, 6, 3],
"resnet_in_channels": [64, 256, 512, 1024],
"resnet_out_channels": [256, 512, 1024, 2048],
# fpn
"fpn_in_channels": [256, 512, 1024, 2048],
"fpn_out_channels": 256,
"fpn_num_outs": 5,
# rpn
"rpn_in_channels": 256,
"rpn_feat_channels": 256,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 1.0,
"rpn_cls_out_channels": 1,
"rpn_target_means": [0., 0., 0., 0.],
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0],
# bbox_assign_sampler
"neg_iou_thr": 0.3,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.3,
"num_bboxes": 245520,
"num_gts": 128,
"num_expected_neg": 256,
"num_expected_pos": 128,
# proposal
"activate_num_classes": 2,
"use_sigmoid_cls": True,
# roi_align
"roi_layer": dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2),
"roi_align_out_channels": 256,
"roi_align_featmap_strides": [4, 8, 16, 32],
"roi_align_finest_scale": 56,
"roi_sample_num": 640,
# bbox_assign_sampler_stage2
"neg_iou_thr_stage2": 0.5,
"pos_iou_thr_stage2": 0.5,
"min_pos_iou_stage2": 0.5,
"num_bboxes_stage2": 2000,
"num_expected_pos_stage2": 128,
"num_expected_neg_stage2": 512,
"num_expected_total_stage2": 512,
# rcnn
"rcnn_num_layers": 2,
"rcnn_in_channels": 256,
"rcnn_fc_out_channels": 1024,
"rcnn_mask_out_channels": 256,
"rcnn_loss_cls_weight": 1,
"rcnn_loss_reg_weight": 1,
"rcnn_loss_mask_fb_weight": 1,
"rcnn_target_means": [0., 0., 0., 0.],
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],
# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 2000,
"rpn_proposal_max_num": 2000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 0,
# test proposal
"rpn_nms_across_levels": False,
"rpn_nms_pre": 1000,
"rpn_nms_post": 1000,
"rpn_max_num": 1000,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 0,
"test_score_thr": 0.05,
"test_iou_thr": 0.5,
"test_max_per_img": 100,
"test_batch_size": 2,
"rpn_head_use_sigmoid": True,
"rpn_head_weight": 1.0,
"mask_thr_binary": 0.5,
# LR
"base_lr": 0.02,
"base_step": 58633,
"total_epoch": 13,
"warmup_step": 500,
"warmup_ratio": 1/3.0,
"sgd_momentum": 0.9,
# train
"batch_size": 2,
"loss_scale": 1,
"momentum": 0.91,
"weight_decay": 1e-4,
"pretrain_epoch_size": 0,
"epoch_size": 12,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 12,
"save_checkpoint_path": "./",
"mindrecord_dir": "/home/mask_rcnn/MindRecord_COCO2017_Train",
"coco_root": "/home/mask_rcnn/coco2017/",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instance_set": "annotations/instances_{}.json",
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81
})

View File

@ -15,15 +15,13 @@
"""
convert resnet50 pretrain model to faster_rcnn backbone pretrain model
"""
import argparse
from mindspore.train.serialization import load_checkpoint, save_checkpoint
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from model_utils.config import config
parser = argparse.ArgumentParser(description='load_ckpt')
parser.add_argument('--ckpt_file', type=str, default='', help='ckpt file path')
args_opt = parser.parse_args()
def load_weights(model_path, use_fp16_weight):
"""
load resnet50 pretrain checkpoint file.
@ -60,5 +58,5 @@ def load_weights(model_path, use_fp16_weight):
return param_list
if __name__ == "__main__":
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=False)
parameter_list = load_weights(config.ckpt_file, use_fp16_weight=False)
save_checkpoint(parameter_list, "resnet50_backbone.ckpt")

View File

@ -17,15 +17,25 @@
from __future__ import division
import os
import re
import numpy as np
from numpy import random
import cv2
import mmcv
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
from src.config import config
from model_utils.config import config
config.mask_shape = (28, 28)
if config.enable_modelarts and config.need_modelarts_dataset_unzip:
config.coco_root = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
else:
config.coco_root = config.data_path
config.mindrecord_dir = os.path.join(config.coco_root, config.mindrecord_dir)
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
@ -385,8 +395,12 @@ def create_coco_label(is_training):
if is_training:
data_type = config.train_data_type
#Classes need to train or test.
train_cls = config.coco_classes
# Classes need to train or test.
# train_cls = config.coco_classes
train_cls = [i for i in re.findall(r'[a-zA-Z\s]+', config.coco_classes) if i != ' ']
train_cls = np.array(train_cls)
print(train_cls)
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i

View File

@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -20,7 +20,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils
from src.config import config
from model_utils.config import config
_init_value = np.array(0.0)
summary_init = {

View File

@ -15,10 +15,17 @@
"""train MaskRcnn and get checkpoint files."""
import os
import time
import argparse
import ast
import os
import re
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.lr_schedule import dynamic_lr
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
@ -31,32 +38,78 @@ from mindspore.nn import Momentum
from mindspore.common import set_seed
from mindspore.communication.management import get_rank, get_group_size
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.lr_schedule import dynamic_lr
set_seed(1)
parser = argparse.ArgumentParser(description="MaskRcnn training")
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default is false.")
parser.add_argument("--do_train", type=ast.literal_eval, default=True, help="Do train or not, default is true.")
parser.add_argument("--do_eval", type=ast.literal_eval, default=False, help="Do eval or not, default is false.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
args_opt = parser.parse_args()
lss = [int(re.findall(r'[0-9]+', i)[0]) for i in config.feature_shapes]
config.feature_shapes = [(lss[2*i], lss[2*i+1]) for i in range(int(len(lss)/2))]
config.roi_layer = dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2)
config.warmup_ratio = 1/3.0
config.mask_shape = (28, 28)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
def modelarts_pre_process():
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
int(int(time.time() - s_time) % 60)))
print("Extract Done")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if __name__ == '__main__':
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
print("#" * 200, os.listdir(save_dir_1))
print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.pre_trained = os.path.join(config.dataset_path, config.ckpt_path)
config.save_checkpoint_path = config.output_path
config.mindrecord_dir = os.path.join(config.dataset_path, config.mindrecord_dir)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_maskrcnn():
print("Start train for maskrcnn!")
if not args_opt.do_eval and args_opt.run_distribute:
if not config.do_eval and config.run_distribute:
init()
rank = get_rank()
device_num = get_group_size()
@ -68,7 +121,7 @@ if __name__ == '__main__':
print("Start create dataset!")
# It will generate mindrecord file in args_opt.mindrecord_dir,
# It will generate mindrecord file in config.mindrecord_dir,
# and the file name is MaskRcnn.mindrecord0, 1, ... file_num.
prefix = "MaskRcnn.mindrecord"
mindrecord_dir = config.mindrecord_dir
@ -76,13 +129,13 @@ if __name__ == '__main__':
if rank == 0 and not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.coco_root):
if config.dataset == "coco":
if os.path.isdir(config.data_path):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
raise Exception("coco_root not exits.")
raise Exception("data_path not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
@ -93,9 +146,8 @@ if __name__ == '__main__':
while not os.path.exists(mindrecord_file+".db"):
time.sleep(5)
if not args_opt.only_create_dataset:
loss_scale = float(config.loss_scale)
if not config.only_create_dataset:
# loss_scale = float(config.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as MaskRcnn.mindrecord0.
dataset = create_maskrcnn_dataset(mindrecord_file, batch_size=config.batch_size,
device_num=device_num, rank_id=rank)
@ -107,7 +159,7 @@ if __name__ == '__main__':
net = Mask_Rcnn_Resnet50(config=config)
net = net.set_train()
load_path = args_opt.pre_trained
load_path = config.pre_trained
if load_path != "":
param_dict = load_checkpoint(load_path)
if config.pretrain_epoch_size == 0:
@ -123,7 +175,7 @@ if __name__ == '__main__':
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
if config.run_distribute:
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
@ -141,3 +193,7 @@ if __name__ == '__main__':
model = Model(net)
model.train(config.epoch_size, dataset, callbacks=cb)
if __name__ == '__main__':
train_maskrcnn()

View File

@ -16,13 +16,8 @@
import os
import pytest
import numpy as np
from model_zoo.official.cv.maskrcnn.src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from model_zoo.official.cv.maskrcnn.src.config import config
from tests.st.model_zoo_tests import utils
from mindspore import Tensor, context, export
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@ -32,24 +27,20 @@ def test_maskrcnn_export():
"""
export maskrcnn air.
"""
net = Mask_Rcnn_Resnet50(config=config)
net.set_train(False)
old_list = ["(config=config)", "(net, param_dict_new)"]
new_list = ["(config=config\\n) '''", "(net, param_dict_new)\\n '''"]
bs = config.test_batch_size
img = Tensor(np.zeros([bs, 3, 768, 1280], np.float16))
img_metas = Tensor(np.zeros([bs, 4], np.float16))
gt_bboxes = Tensor(np.zeros([bs, 128, 4], np.float16))
gt_labels = Tensor(np.zeros([bs, 128], np.int32))
gt_num = Tensor(np.zeros([bs, 128], np.bool))
gt_mask = Tensor(np.zeros([bs, 128], np.bool))
input_data = [img, img_metas, gt_bboxes, gt_labels, gt_num, gt_mask]
export(net, *input_data, file_name="maskrcnn", file_format="AIR")
file_name = "maskrcnn.air"
assert os.path.exists(file_name)
os.remove(file_name)
cur_path = os.getcwd()
model_path = "{}/../../../../model_zoo/official/cv".format(cur_path)
model_name = "maskrcnn"
utils.copy_files(model_path, cur_path, model_name)
cur_model_path = os.path.join(cur_path, model_name)
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "export.py"))
# ckpt_path = os.path.join(utils.ckpt_root, "bgcf/bgcf_trained.ckpt")
exec_export_shell = "cd {}; python export.py --config_path default_config.yaml".format(model_name)
os.system(exec_export_shell)
assert os.path.exists(os.path.join(cur_model_path, "{}.air".format(model_name)))
if __name__ == '__main__':
test_maskrcnn_export()