forked from mindspore-Ecosystem/mindspore
!22654 add fasterrcnn demo for cross-silo federated
Merge pull request !22654 from zhangqi/0831
This commit is contained in:
commit
90e807f066
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "recovery.json"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: CPU
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
# config
|
||||
img_width: 1280
|
||||
img_height: 768
|
||||
keep_ratio: True
|
||||
flip_ratio: 0.5
|
||||
expand_ratio: 1.0
|
||||
|
||||
# anchor
|
||||
feature_shapes:
|
||||
- [192, 320]
|
||||
- [96, 160]
|
||||
- [48, 80]
|
||||
- [24, 40]
|
||||
- [12, 20]
|
||||
anchor_scales: [8]
|
||||
anchor_ratios: [0.5, 1.0, 2.0]
|
||||
anchor_strides: [4, 8, 16, 32, 64]
|
||||
num_anchors: 3
|
||||
|
||||
# resnet
|
||||
resnet_block: [3, 4, 6, 3]
|
||||
resnet_in_channels: [64, 256, 512, 1024]
|
||||
resnet_out_channels: [256, 512, 1024, 2048]
|
||||
|
||||
# fpn
|
||||
fpn_in_channels: [256, 512, 1024, 2048]
|
||||
fpn_out_channels: 256
|
||||
fpn_num_outs: 5
|
||||
|
||||
# rpn
|
||||
rpn_in_channels: 256
|
||||
rpn_feat_channels: 256
|
||||
rpn_loss_cls_weight: 1.0
|
||||
rpn_loss_reg_weight: 1.0
|
||||
rpn_cls_out_channels: 1
|
||||
rpn_target_means: [0., 0., 0., 0.]
|
||||
rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# bbox_assign_sampler
|
||||
neg_iou_thr: 0.3
|
||||
pos_iou_thr: 0.7
|
||||
min_pos_iou: 0.3
|
||||
num_bboxes: 245520
|
||||
num_gts: 128
|
||||
num_expected_neg: 256
|
||||
num_expected_pos: 128
|
||||
|
||||
# proposal
|
||||
activate_num_classes: 2
|
||||
use_sigmoid_cls: True
|
||||
|
||||
# roi_align
|
||||
roi_layer: {type: 'RoIAlign', out_size: 7, sample_num: 2}
|
||||
roi_align_out_channels: 256
|
||||
roi_align_featmap_strides: [4, 8, 16, 32]
|
||||
roi_align_finest_scale: 56
|
||||
roi_sample_num: 640
|
||||
|
||||
# bbox_assign_sampler_stage2
|
||||
neg_iou_thr_stage2: 0.5
|
||||
pos_iou_thr_stage2: 0.5
|
||||
min_pos_iou_stage2: 0.5
|
||||
num_bboxes_stage2: 2000
|
||||
num_expected_pos_stage2: 128
|
||||
num_expected_neg_stage2: 512
|
||||
num_expected_total_stage2: 512
|
||||
|
||||
# rcnn
|
||||
rcnn_num_layers: 2
|
||||
rcnn_in_channels: 256
|
||||
rcnn_fc_out_channels: 1024
|
||||
rcnn_loss_cls_weight: 1
|
||||
rcnn_loss_reg_weight: 1
|
||||
rcnn_target_means: [0., 0., 0., 0.]
|
||||
rcnn_target_stds: [0.1, 0.1, 0.2, 0.2]
|
||||
|
||||
# train proposal
|
||||
rpn_proposal_nms_across_levels: False
|
||||
rpn_proposal_nms_pre: 2000
|
||||
rpn_proposal_nms_post: 2000
|
||||
rpn_proposal_max_num: 2000
|
||||
rpn_proposal_nms_thr: 0.7
|
||||
rpn_proposal_min_bbox_size: 0
|
||||
|
||||
# test proposal
|
||||
rpn_nms_across_levels: False
|
||||
rpn_nms_pre: 1000
|
||||
rpn_nms_post: 1000
|
||||
rpn_max_num: 1000
|
||||
rpn_nms_thr: 0.7
|
||||
rpn_min_bbox_min_size: 0
|
||||
test_score_thr: 0.05
|
||||
test_iou_thr: 0.5
|
||||
test_max_per_img: 100
|
||||
test_batch_size: 2
|
||||
|
||||
rpn_head_use_sigmoid: True
|
||||
rpn_head_weight: 1.0
|
||||
|
||||
# LR
|
||||
base_lr: 0.01
|
||||
warmup_step: 500
|
||||
warmup_ratio: 0.0625
|
||||
sgd_step: [8, 11]
|
||||
sgd_momentum: 0.9
|
||||
|
||||
# train
|
||||
batch_size: 4
|
||||
loss_scale: 256
|
||||
momentum: 0.91
|
||||
weight_decay: 0.00001
|
||||
epoch_size: 20
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 5
|
||||
keep_checkpoint_max: 20
|
||||
save_checkpoint_path: "./"
|
||||
|
||||
# fl
|
||||
server_mode: "FEDERATED_LEARNING"
|
||||
ms_role: "MS_WORKER"
|
||||
worker_num: 1
|
||||
server_num: 1
|
||||
scheduler_ip: "10.113.216.40"
|
||||
scheduler_port: 8113
|
||||
fl_server_port: 6666
|
||||
start_fl_job_threshold: 1
|
||||
start_fl_job_time_window: 3000
|
||||
update_model_ratio: 1.0
|
||||
update_model_time_window: 3000
|
||||
fl_name: "Lenet"
|
||||
fl_iteration_num: 25
|
||||
client_epoch_num: 20
|
||||
client_batch_size: 4
|
||||
client_learning_rate: 0.01
|
||||
worker_step_num_per_iteration: 65
|
||||
scheduler_manage_port: 11202
|
||||
config_file_path: ""
|
||||
encrypt_type: "NOT_ENCRYPT"
|
||||
dataset_path: ""
|
||||
user_id: 0
|
||||
|
||||
|
||||
# Number of threads used to process the dataset in parallel
|
||||
num_parallel_workers: 8
|
||||
# Parallelize Python operations with multiple worker processes
|
||||
python_multiprocessing: True
|
||||
mindrecord_dir: "./datasets/coco_split/split_1000/mindrecord_9"
|
||||
coco_root: "./datasets/coco2017/coco2017"
|
||||
train_data_type: "train2017"
|
||||
val_data_type: "val2017"
|
||||
#instance_set: "annotations/instances_{}.json"
|
||||
instance_set: "./datasets/coco_split/split_1000/train_9.json"
|
||||
coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||
'teddy bear', 'hair drier', 'toothbrush']
|
||||
num_classes: 81
|
||||
|
||||
# train.py FasterRcnn training
|
||||
run_distribute: False
|
||||
dataset: "coco"
|
||||
#pre_trained: "./cache/train/fasterrcnn/faster_rcnn-12_7393.ckpt"
|
||||
pre_trained: "./resnet50_backbone.ckpt"
|
||||
device_id: 0
|
||||
device_num: 1
|
||||
rank_id: 0
|
||||
image_dir: ''
|
||||
anno_path: ''
|
||||
backbone: 'resnet_v1_50'
|
||||
|
||||
# eval.py FasterRcnn evaluation
|
||||
ann_file: './cache/data/annotations/instances_val2017.json'
|
||||
checkpoint_path: "./cache/train/fasterrcnn/faster_rcnn-12_7393.ckpt"
|
||||
|
||||
# export.py fasterrcnn_export
|
||||
file_name: "faster_rcnn"
|
||||
file_format: "AIR"
|
||||
ckpt_file: "./cache/train/fasterrcnn/faster_rcnn-12_7393.ckpt"
|
||||
|
||||
# postprocess ("./src/config_50.yaml")
|
||||
#ann_file: ''
|
||||
result_path: ''
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
result_dir: "result files path."
|
||||
label_dir: "image file path."
|
||||
|
||||
device_target: "device where the code will be executed, default is CPU"
|
||||
file_name: "output file name."
|
||||
dataset: "Dataset type, 'coco' or any other value for a custom dataset"
|
||||
parameter_server: 'Run parameter server train'
|
||||
width: 'input width'
|
||||
height: 'input height'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
|
||||
run_distribute: 'Run distribute, default is false.'
|
||||
do_train: 'Do train or not, default is true.'
|
||||
do_eval: 'Do eval or not, default is false.'
|
||||
pre_trained: 'Pretrained checkpoint path'
|
||||
device_id: 'Device id, default is 0.'
|
||||
device_num: 'Use device nums, default is 1.'
|
||||
rank_id: 'Rank id, default is 0.'
|
||||
file_format: 'file format'
|
||||
ann_file: "Ann file, default is val.json."
|
||||
checkpoint_path: "Checkpoint file path."
|
||||
ckpt_file: 'fasterrcnn ckpt file.'
|
||||
result_path: "result file path."
|
||||
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152"
|
||||
|
||||
---
|
||||
device_target: ['GPU', 'CPU', 'Ascend']
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
||||
dataset_name: ["cifar10", "imagenet2012"]
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
# Stop every process of a federated-learning job, identified by the
# scheduler port that appears on its command line.
parser = argparse.ArgumentParser(description="Finish test_cross_silo_femnist.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)

# Unknown arguments are tolerated and ignored.
args, _ = parser.parse_known_args()
scheduler_port = args.scheduler_port

# Shell pipeline: list processes, keep those whose command line contains
# "scheduler_port=<port>", drop the grep helpers and anything containing
# "finish", take the PID column and kill each PID with SIGKILL.
cmd = "".join((
    'pid=`ps -ef|grep "scheduler_port=',
    str(scheduler_port),
    '" ',
    ' | grep -v "grep" | grep -v "finish" |awk \'{print $2}\'` && ',
    'for id in $pid; do kill -9 $id && echo "killed $id"; done',
))

subprocess.call(["bash", "-c", cmd])
|
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
import time
|
||||
|
||||
from mindspore.common import set_seed
|
||||
from src.dataset import data_to_mindrecord_byte_image
|
||||
from src.model_utils.config import config
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
def generate_coco_mindrecord():
    """Create the FasterRcnn MindRecord file(s) when they are not present yet.

    The output location is ``config.mindrecord_dir`` and the file prefix is
    ``FasterRcnn.mindrecord``. When ``config.dataset`` is ``"coco"`` the data
    is read from ``config.coco_root``; otherwise it is read from
    ``config.image_dir`` and ``config.anno_path``.

    Raises:
        ValueError: if the dataset inputs required for the conversion do not
            exist. Previously this case only printed a message and then
            waited forever for a ``.db`` file that nothing would create.
    """
    prefix = "FasterRcnn.mindrecord"
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    print("CHECKING MINDRECORD FILES ...")

    # Only rank 0 performs the conversion, and only when the file is missing.
    if rank == 0 and not os.path.exists(mindrecord_file):
        os.makedirs(mindrecord_dir, exist_ok=True)
        if config.dataset == "coco":
            if not os.path.isdir(config.coco_root):
                print("Please make sure config:coco_root is valid.")
                raise ValueError(config.coco_root)
            print("Create Mindrecord. It may take some time.")
            data_to_mindrecord_byte_image(config, "coco", True, prefix, 1)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
        else:
            if not os.path.isdir(config.image_dir):
                print("Please make sure config:image_dir is valid.")
                raise ValueError(config.image_dir)
            if not os.path.exists(config.anno_path):
                print("Please make sure config:anno_path is valid.")
                raise ValueError(config.anno_path)
            print("Create Mindrecord. It may take some time.")
            data_to_mindrecord_byte_image(config, "other", True, prefix)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))

    # Block until the companion ".db" file shows up next to the record file.
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)

    print("CHECKING MINDRECORD FILES DONE!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
generate_coco_mindrecord()
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
|
||||
if config.backbone in ("resnet_v1.5_50", "resnet_v1_101", "resnet_v1_152"):
|
||||
from src.FasterRcnn.faster_rcnn_resnet import Faster_Rcnn_Resnet
|
||||
elif config.backbone == "resnet_v1_50":
|
||||
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
|
||||
|
||||
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name``.

    Only ``"faster_rcnn"`` is supported; it is constructed from the
    module-level ``config``. Extra positional and keyword arguments are
    accepted but not used.

    Raises:
        NotImplementedError: for any other ``name``.
    """
    if name != "faster_rcnn":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return Faster_Rcnn_Resnet(config=config)
|
|
@ -0,0 +1,3 @@
|
|||
Cython
|
||||
pycocotools
|
||||
mmcv==0.2.14
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
# Launch the scheduler process of the cross-silo FasterRCNN job in the
# background, inside a fresh ./scheduler/ directory.
parser = argparse.ArgumentParser(description="Run test_cross_silo_fasterrcnn.py case")
for option, caster, fallback in (
        ("--device_target", str, "CPU"),
        ("--server_mode", str, "FEDERATED_LEARNING"),
        ("--worker_num", int, 1),
        ("--server_num", int, 2),
        ("--scheduler_ip", str, "127.0.0.1"),
        ("--scheduler_port", int, 8113),
        ("--scheduler_manage_port", int, 11202),
        ("--config_file_path", str, ""),
        ("--dataset_path", str, ""),
):
    parser.add_argument(option, type=caster, default=fallback)

# Unknown arguments are tolerated and ignored.
args, _ = parser.parse_known_args()

# Flags forwarded to test_fl_fasterrcnn.py, in the order they are emitted.
forwarded = (
    ("device_target", args.device_target),
    ("server_mode", args.server_mode),
    ("ms_role", "MS_SCHED"),
    ("worker_num", args.worker_num),
    ("server_num", args.server_num),
    ("config_file_path", args.config_file_path),
    ("scheduler_ip", args.scheduler_ip),
    ("scheduler_port", args.scheduler_port),
    ("scheduler_manage_port", args.scheduler_manage_port),
    ("dataset_path", args.dataset_path),
    ("user_id", 0),
)

# Shell prologue: recreate the working directory, enter it, set the log
# level, then start the Python entry script.
# NOTE(review): ${script_self} is not assigned anywhere in this command;
# presumably it is expected from the environment -- confirm.
cmd_sched = (
    "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
    "mkdir ${execute_path}/scheduler/ &&"
    "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
    "python ${self_path}/../test_fl_fasterrcnn.py"
)
cmd_sched += "".join(" --{}={}".format(key, value) for key, value in forwarded)
# Run in the background and capture stdout/stderr in scheduler.log.
cmd_sched += " > scheduler.log 2>&1 &"

subprocess.call(["bash", "-c", cmd_sched])
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
# Launch the local federated-learning server processes of the cross-silo
# FasterRCNN job. Each server runs in the background inside its own fresh
# ./server_<i>/ directory and listens on fl_server_port + i.
import time

parser = argparse.ArgumentParser(description="Run test_cross_silo_fasterrcnn.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
# -1 means "start all server_num servers on this machine".
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")

parser.add_argument("--dataset_path", type=str, default="")

# Unknown arguments are tolerated and ignored.
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type

dataset_path = args.dataset_path

if local_server_num == -1:
    local_server_num = server_num

# Explicit check instead of `assert`, which is removed when Python runs with -O.
if local_server_num > server_num:
    raise ValueError("The local server number should not be bigger than total server number.")

for i in range(local_server_num):
    # Shell prologue: recreate ./server_<i>/, enter it and set the log level.
    # NOTE(review): ${script_self} is not assigned anywhere in this command;
    # presumably it is expected from the environment -- confirm.
    cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
    cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
    cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
    cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
    cmd_server += "python ${self_path}/../test_fl_fasterrcnn.py"
    cmd_server += " --device_target=" + device_target
    cmd_server += " --server_mode=" + server_mode
    cmd_server += " --ms_role=MS_SERVER"
    cmd_server += " --worker_num=" + str(worker_num)
    cmd_server += " --server_num=" + str(server_num)
    cmd_server += " --scheduler_ip=" + scheduler_ip
    cmd_server += " --scheduler_port=" + str(scheduler_port)
    # Every local server gets its own port, counted up from the base port.
    cmd_server += " --fl_server_port=" + str(fl_server_port + i)
    cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
    cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
    cmd_server += " --update_model_ratio=" + str(update_model_ratio)
    cmd_server += " --update_model_time_window=" + str(update_model_time_window)
    cmd_server += " --fl_name=" + fl_name
    cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
    cmd_server += " --config_file_path=" + str(config_file_path)
    cmd_server += " --client_epoch_num=" + str(client_epoch_num)
    cmd_server += " --client_batch_size=" + str(client_batch_size)
    cmd_server += " --client_learning_rate=" + str(client_learning_rate)
    cmd_server += " --encrypt_type=" + str(encrypt_type)
    cmd_server += " --dataset_path=" + str(dataset_path)
    cmd_server += " --user_id=" + str(0)
    # Run in the background and capture stdout/stderr in server.log.
    cmd_server += " > server.log 2>&1 &"

    # Short pause before each launch so the servers do not start simultaneously.
    time.sleep(0.3)
    subprocess.call(['bash', '-c', cmd_server])
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
# Launch the local federated-learning worker processes of the cross-silo
# FasterRCNN job. Each worker runs in the background inside its own fresh
# ./worker_<i>/ directory and is passed --user_id=<i>.
parser = argparse.ArgumentParser(description="Run test_cross_silo_fasterrcnn.py case")
parser.add_argument("--device_target", type=str, default="GPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
# -1 means "start all worker_num workers on this machine".
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")

# Unknown arguments are tolerated and ignored.
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path

if local_worker_num == -1:
    local_worker_num = worker_num

# Explicit check instead of `assert`, which is removed when Python runs with -O.
if local_worker_num > worker_num:
    raise ValueError("The local worker number should not be bigger than total worker number.")

for i in range(local_worker_num):
    # Shell prologue: recreate ./worker_<i>/, enter it and set the log level.
    # NOTE(review): ${script_self} is not assigned anywhere in this command;
    # presumably it is expected from the environment -- confirm.
    cmd_worker = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
    cmd_worker += "rm -rf ${execute_path}/worker_" + str(i) + "/ &&"
    cmd_worker += "mkdir ${execute_path}/worker_" + str(i) + "/ &&"
    cmd_worker += "cd ${execute_path}/worker_" + str(i) + "/ || exit && export GLOG_v=1 &&"
    # Worker i is pinned to GPU index i + 4.
    # NOTE(review): the offset 4 is hard-coded -- confirm it matches the host.
    cmd_worker += "export CUDA_VISIBLE_DEVICES=" + str(i+4) + "&&"
    cmd_worker += "python ${self_path}/../test_fl_fasterrcnn.py"
    cmd_worker += " --device_target=" + device_target
    cmd_worker += " --server_mode=" + server_mode
    cmd_worker += " --ms_role=MS_WORKER"
    cmd_worker += " --worker_num=" + str(worker_num)
    cmd_worker += " --server_num=" + str(server_num)
    cmd_worker += " --scheduler_ip=" + scheduler_ip
    cmd_worker += " --scheduler_port=" + str(scheduler_port)
    cmd_worker += " --config_file_path=" + str(config_file_path)
    cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
    cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
    cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
    cmd_worker += " --dataset_path=" + str(dataset_path)
    cmd_worker += " --user_id=" + str(i)
    # Run in the background and capture stdout/stderr in worker.log.
    cmd_worker += " > worker.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_worker])
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn Init."""
|
||||
|
||||
from .resnet import ResNetFea, ResidualBlockUsing
|
||||
from .resnet50v1 import ResidualBlockUsing_V1
|
||||
from .bbox_assign_sample import BboxAssignSample
|
||||
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
|
||||
from .fpn_neck import FeatPyramidNeck
|
||||
from .proposal_generator import Proposal
|
||||
from .rcnn import Rcnn
|
||||
from .rpn import RPN
|
||||
from .roi_align import SingleRoIExtractor
|
||||
from .anchor_generator import AnchorGenerator
|
||||
|
||||
__all__ = [
|
||||
"ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn",
|
||||
"FeatPyramidNeck", "Proposal", "Rcnn",
|
||||
"RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing", "ResidualBlockUsing_V1"
|
||||
]
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn anchor generator."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
class AnchorGenerator:
    """Anchor generator for FasterRcnn.

    Builds one set of base anchors from ``base_size``, ``scales`` and
    ``ratios`` at construction time, and tiles them over a feature map on
    request.
    """

    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
        """Store the settings and precompute the base anchors.

        Args:
            base_size: side length of the reference square.
            scales: sequence of scale factors (converted to an array).
            ratios: sequence of height/width ratios (converted to an array).
            scale_major: selects which of scales/ratios varies fastest in the
                flattened anchor list.
            ctr: optional ``(x, y)`` anchor centre; when ``None`` the centre
                of the reference square is used.
        """
        self.base_size = base_size
        self.scales = np.array(scales)
        self.ratios = np.array(ratios)
        self.scale_major = scale_major
        self.ctr = ctr
        self.base_anchors = self.gen_base_anchors()

    def gen_base_anchors(self):
        """Return the rounded base anchors as rows of ``[x1, y1, x2, y2]``."""
        side = self.base_size
        if self.ctr is not None:
            center_x, center_y = self.ctr
        else:
            # Width and height of the reference square are equal, so both
            # centre coordinates share one value.
            center_x = center_y = 0.5 * (side - 1)

        height_factors = np.sqrt(self.ratios)
        width_factors = 1 / height_factors
        if self.scale_major:
            # Rows indexed by ratio, columns by scale.
            widths = side * width_factors[:, None] * self.scales[None, :]
            heights = side * height_factors[:, None] * self.scales[None, :]
        else:
            # Rows indexed by scale, columns by ratio.
            widths = side * self.scales[:, None] * width_factors[None, :]
            heights = side * self.scales[:, None] * height_factors[None, :]

        half_w = 0.5 * (widths.reshape(-1) - 1)
        half_h = 0.5 * (heights.reshape(-1) - 1)
        corners = [
            center_x - half_w,
            center_y - half_h,
            center_x + half_w,
            center_y + half_h,
        ]
        return np.stack(corners, axis=-1).round()

    def _meshgrid(self, x, y, row_major=True):
        """Return flattened grid coordinates for every (y, x) pair.

        ``x`` is repeated as a whole once per element of ``y``; each element
        of ``y`` is repeated ``len(x)`` times. With ``row_major=False`` the
        two results are returned in swapped order.
        """
        grid_x = np.tile(x, len(y))
        grid_y = np.repeat(y, len(x))
        return (grid_x, grid_y) if row_major else (grid_y, grid_x)

    def grid_anchors(self, featmap_size, stride=16):
        """Tile the base anchors over a feature map.

        Args:
            featmap_size: ``(height, width)`` of the feature map.
            stride: distance between neighbouring grid positions.

        Returns:
            Array of shape ``(K * A, 4)`` where ``K`` is the number of grid
            positions and ``A`` the number of base anchors.
        """
        rows, cols = featmap_size
        offsets_x, offsets_y = self._meshgrid(np.arange(cols) * stride,
                                              np.arange(rows) * stride)
        # One [dx, dy, dx, dy] row per grid position, in the anchors' dtype.
        offsets = np.column_stack((offsets_x, offsets_y, offsets_x, offsets_y))
        offsets = offsets.astype(self.base_anchors.dtype)

        # (1, A, 4) + (K, 1, 4) broadcasts to (K, A, 4); the first feat_w
        # positions correspond to the first row of the feature map.
        tiled = self.base_anchors[None, :, :] + offsets[:, None, :]
        return tiled.reshape(-1, 4)
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn positive and negative sample screening for RPN."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class BboxAssignSample(nn.Cell):
    """
    Bbox assigner and sampler definition (RPN stage).

    For one image, every candidate box is assigned to a ground-truth box, to background (0)
    or to "ignore" (-1) based on IoU. A fixed number of positives and negatives is then drawn
    at random, and the encoded regression targets, labels and weights of the drawn samples are
    scattered back into per-box tensors.

    Args:
        config (dict): Config. Reads neg_iou_thr, pos_iou_thr, min_pos_iou, num_gts,
            num_expected_pos and num_expected_neg.
        batch_size (int): Batchsize (stored only; construct works on a single image).
        num_bboxes (int): The anchor nums.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tuple of four tensors.
        bbox_targets: encoded box targets, scattered into shape (num_bboxes, 4)
        bbox_weights: bool mask of sampled positives, shape (num_bboxes,)
        labels: gt label of each sampled positive, shape (num_bboxes,)
        label_weights: bool mask of all sampled boxes (positive and negative), shape (num_bboxes,)

    Examples:
        BboxAssignSample(config, 2, 1024, True)
    """

    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSample, self).__init__()
        cfg = config
        self.dtype = np.float32
        self.ms_type = mstype.float32
        self.batch_size = batch_size

        # IoU thresholds as scalar tensors so they can be compared inside the graph.
        self.neg_iou_thr = Tensor(cfg.neg_iou_thr, self.ms_type)
        self.pos_iou_thr = Tensor(cfg.pos_iou_thr, self.ms_type)
        self.min_pos_iou = Tensor(cfg.min_pos_iou, self.ms_type)
        self.zero_thr = Tensor(0.0, self.ms_type)

        self.num_bboxes = num_bboxes
        self.num_gts = cfg.num_gts
        self.num_expected_pos = cfg.num_expected_pos
        self.num_expected_neg = cfg.num_expected_neg
        self.add_gt_as_proposals = add_gt_as_proposals

        if self.add_gt_as_proposals:
            self.label_inds = Tensor(np.arange(1, self.num_gts + 1))

        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        self.scatterNdUpdate = P.ScatterNdUpdate()
        self.scatterNd = P.ScatterNd()
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()

        # Constant per-box assignment vectors: -1 = ignore, 0 = background, k > 0 = gt index k - 1.
        self.assigned_gt_inds = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.zeros(num_bboxes, dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.ones(num_bboxes, dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.ones(self.num_expected_pos, dtype=np.int32))

        # np.bool_ is used because the bare `np.bool` alias was removed in NumPy 1.24.
        self.check_neg_mask = Tensor(np.ones(self.num_expected_neg - self.num_expected_pos, dtype=np.bool_))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
        # Sentinel coordinates written over invalid gt boxes (-1) and invalid candidates (-2).
        self.check_gt_one = Tensor(np.full((self.num_gts, 4), -1, dtype=self.dtype))
        self.check_anchor_two = Tensor(np.full((self.num_bboxes, 4), -2, dtype=self.dtype))

    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        """
        Assign and sample the boxes of one image.

        Args:
            gt_bboxes_i (Tensor): gt boxes; reshaped against (num_gts, 4) sentinels below.
            gt_labels_i (Tensor): gt labels, gathered with the assigned gt index.
            valid_mask (Tensor): validity flag per candidate box.
            bboxes (Tensor): candidate boxes; masked against (num_bboxes, 4) sentinels below.
            gt_valids (Tensor): validity flag per gt box.

        Returns:
            Tuple (bbox_targets, bbox_weights, labels, label_weights), see the class docstring.
        """
        # Replace invalid gt boxes / candidate boxes with the sentinel coordinates.
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32),
                                                                   (self.num_gts, 1)), (1, 4)), mstype.bool_),
                                  gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32),
                                                              (self.num_bboxes, 1)), (1, 4)), mstype.bool_),
                             bboxes, self.check_anchor_two)

        # Rows are indexed by gt in the loop below, columns by candidate box.
        overlaps = self.iou(bboxes, gt_bboxes_i)

        # Best gt for every box (reduce over axis 0) and best IoU for every gt (reduce over axis 1).
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)

        # 0 <= best IoU < neg_iou_thr -> background (0); all other boxes start as ignore (-1).
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr),
                                              self.less(max_overlaps_w_gt, self.neg_iou_thr))
        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)

        # best IoU >= pos_iou_thr -> 1-based index of the best-matching gt.
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask,
                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
        assigned_gt_inds4 = assigned_gt_inds3
        # For every gt j, boxes whose IoU equals that gt's maximum (and is >= min_pos_iou) get j + 1.
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])

            pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou),
                                         self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))

            assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)

        # Invalid candidates are forced back to ignore.
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)

        # The mask returned by the sampler is discarded; validity is recomputed from the positive count.
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))

        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        # Slot k is valid when k < number of positive boxes.
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))

        # Convert the 1-based assignment to a 0-based gt index; invalid slots collapse to index 0.
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))

        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))

        # num_pos counts the *unused* positive slots; that many extra negative slots are enabled
        # on top of the (num_expected_neg - num_expected_pos) slots that are always enabled.
        num_pos = self.cast(self.logicalnot(valid_pos_index), self.ms_type)
        num_pos = self.sum_inds(num_pos, -1)
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)

        pos_bboxes_ = self.gatherND(bboxes, pos_index)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)

        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)

        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        # Scatter the sampled values back to their box positions.
        bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
        bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
        labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
        total_index = self.concat((pos_index, neg_index))
        total_valid_index = self.concat((valid_pos_index, valid_neg_index))
        label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))

        return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
               labels_total, self.cast(label_weights_total, mstype.bool_)
|
|
@ -0,0 +1,196 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn tpositive and negative sample screening for Rcnn."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
class BboxAssignSampleForRcnn(nn.Cell):
    """
    Bbox assigner and sampler definition (RCNN / stage-two).

    For one image, every proposal is assigned to a ground-truth box, to background (0) or to
    "ignore" (-1) based on IoU. The gt boxes are prepended to the proposals as extra candidates,
    and a fixed number of positives and negatives is drawn at random.

    Args:
        config (dict): Config. Reads neg_iou_thr_stage2, pos_iou_thr_stage2, min_pos_iou_stage2,
            num_gts, num_expected_pos_stage2, num_expected_neg_stage2 and num_expected_total_stage2.
        batch_size (int): Batchsize (stored only; construct works on a single image).
        num_bboxes (int): The number of proposals per image.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tuple of four tensors; positives come first, negatives second along axis 0.
        total_bboxes: sampled boxes (positives followed by negatives).
        total_deltas: encoded regression targets for positives, zeros for negatives.
        total_labels: gt labels for positives, zeros for negatives.
        total_mask: int32 validity flag of every sampled slot.

    Examples:
        BboxAssignSampleForRcnn(config, 2, 1024, True)
    """

    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSampleForRcnn, self).__init__()
        cfg = config
        self.dtype = np.float32
        self.ms_type = mstype.float32
        self.batch_size = batch_size
        self.neg_iou_thr = cfg.neg_iou_thr_stage2
        self.pos_iou_thr = cfg.pos_iou_thr_stage2
        self.min_pos_iou = cfg.min_pos_iou_stage2
        self.num_gts = cfg.num_gts
        self.num_bboxes = num_bboxes
        self.num_expected_pos = cfg.num_expected_pos_stage2
        self.num_expected_neg = cfg.num_expected_neg_stage2
        self.num_expected_total = cfg.num_expected_total_stage2

        self.add_gt_as_proposals = add_gt_as_proposals
        self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
        # All ones when gt boxes are added as proposals, all zeros otherwise.
        self.add_gt_as_proposals_valid = Tensor(np.full(self.num_gts, self.add_gt_as_proposals, dtype=np.int32))

        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
        self.concat_axis1 = P.Concat(axis=1)
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()

        # Check
        # Sentinel coordinates written over invalid gt boxes (-1) and invalid proposals (-2).
        self.check_gt_one = Tensor(np.full((self.num_gts, 4), -1, dtype=self.dtype))
        self.check_anchor_two = Tensor(np.full((self.num_bboxes, 4), -2, dtype=self.dtype))

        # Init tensor
        # Constant per-box assignment vectors: -1 = ignore, 0 = background, k > 0 = gt index k - 1.
        self.assigned_gt_inds = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.zeros(num_bboxes, dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.ones(num_bboxes, dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.ones(self.num_expected_pos, dtype=np.int32))

        self.gt_ignores = Tensor(np.full(self.num_gts, -1, dtype=np.int32))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
        # np.bool_ is used because the bare `np.bool` alias was removed in NumPy 1.24.
        self.check_neg_mask = Tensor(np.ones(self.num_expected_neg - self.num_expected_pos, dtype=np.bool_))
        # Regression targets and labels emitted for the negative slots.
        self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=self.dtype))
        self.labels_neg_mask = Tensor(np.zeros(self.num_expected_neg, dtype=np.uint8))

        self.reshape_shape_pos = (self.num_expected_pos, 1)
        self.reshape_shape_neg = (self.num_expected_neg, 1)

        self.scalar_zero = Tensor(0.0, dtype=self.ms_type)
        self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=self.ms_type)
        self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=self.ms_type)
        self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=self.ms_type)

    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        """
        Assign and sample the proposals of one image.

        Args:
            gt_bboxes_i (Tensor): gt boxes; masked against (num_gts, 4) sentinels below.
            gt_labels_i (Tensor): gt labels, gathered with the assigned gt index.
            valid_mask (Tensor): validity flag per proposal.
            bboxes (Tensor): proposals; masked against (num_bboxes, 4) sentinels below.
            gt_valids (Tensor): validity flag per gt box.

        Returns:
            Tuple (total_bboxes, total_deltas, total_labels, total_mask), see the class docstring.
        """
        # Replace invalid gt boxes / proposals with the sentinel coordinates.
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32),
                                                                   (self.num_gts, 1)), (1, 4)), mstype.bool_),
                                  gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32),
                                                              (self.num_bboxes, 1)), (1, 4)), mstype.bool_),
                             bboxes, self.check_anchor_two)

        # Rows are indexed by gt in the loop below, columns by proposal.
        overlaps = self.iou(bboxes, gt_bboxes_i)

        # Best gt for every proposal (axis 0) and best IoU for every gt (axis 1).
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)

        # 0 <= best IoU < neg_iou_thr -> background (0); all other proposals start as ignore (-1).
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
                                                                self.scalar_zero),
                                              self.less(max_overlaps_w_gt,
                                                        self.scalar_neg_iou_thr))

        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)

        # best IoU >= pos_iou_thr -> 1-based index of the best-matching gt.
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask,
                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)

        # For every gt j, proposals whose IoU equals that gt's maximum (and is >= min_pos_iou) get j + 1.
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_ac_j = overlaps[j:j+1:1, ::]
            temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
            temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
            pos_mask_j = self.logicaland(temp1, temp2)
            assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)

        # Invalid proposals are forced back to ignore.
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)

        # Prepend the gt boxes as candidates. Valid gts carry their own 1-based index, invalid
        # ones -1; the whole vector is zeroed when add_gt_as_proposals is False.
        bboxes = self.concat((gt_bboxes_i, bboxes))
        label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
        label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
        assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))

        # Get pos index
        # The mask returned by the sampler is discarded; validity is recomputed from the positive count.
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))

        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        # Slot k is valid when k < number of positive candidates.
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))

        # num_pos counts the *unused* positive slots (see the negative sampling below).
        num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), self.ms_type), -1)
        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        pos_index = self.reshape(pos_index, self.reshape_shape_pos)
        valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
        pos_index = pos_index * valid_pos_index

        # Convert the 1-based assignment to a 0-based gt index; invalid slots collapse to index 0.
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
        pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index

        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)

        # Get neg index
        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))

        # (num_expected_neg - num_expected_pos) negative slots are always enabled; one more is
        # enabled for every positive slot that stayed unused.
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
        neg_index = self.reshape(neg_index, self.reshape_shape_neg)

        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
        neg_index = neg_index * valid_neg_index

        pos_bboxes_ = self.gatherND(bboxes, pos_index)

        neg_bboxes_ = self.gatherND(bboxes, neg_index)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)

        # Positives first, negatives second; negatives get zero targets and zero labels.
        total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
        total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
        total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))

        valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
        valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
        total_mask = self.concat((valid_pos_index, valid_neg_index))

        return total_bboxes, total_deltas, total_labels, total_mask
|
|
@ -0,0 +1,488 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn based on ResNet."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from .resnet import ResNetFea, ResidualBlockUsing
|
||||
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
|
||||
from .fpn_neck import FeatPyramidNeck
|
||||
from .proposal_generator import Proposal
|
||||
from .rcnn import Rcnn
|
||||
from .rpn import RPN
|
||||
from .roi_align import SingleRoIExtractor
|
||||
from .anchor_generator import AnchorGenerator
|
||||
|
||||
|
||||
class Faster_Rcnn_Resnet(nn.Cell):
|
||||
"""
|
||||
FasterRcnn Network.
|
||||
|
||||
Note:
|
||||
backbone = resnet
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor.
|
||||
rpn_loss: Scalar, Total loss of RPN subnet.
|
||||
rcnn_loss: Scalar, Total loss of RCNN subnet.
|
||||
rpn_cls_loss: Scalar, Classification loss of RPN subnet.
|
||||
rpn_reg_loss: Scalar, Regression loss of RPN subnet.
|
||||
rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
|
||||
rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.
|
||||
|
||||
Examples:
|
||||
net = Faster_Rcnn_Resnet()
|
||||
"""
|
||||
def __init__(self, config):
    """
    Build every sub-network and constant tensor of the detector from ``config``.

    Args:
        config: Configuration object; its attributes are read directly (batch sizes, anchor
            settings, channel counts, RoI and test settings).
    """
    super(Faster_Rcnn_Resnet, self).__init__()
    self.dtype = np.float32
    self.ms_type = mstype.float32
    self.train_batch_size = config.batch_size
    self.num_classes = config.num_classes
    self.anchor_scales = config.anchor_scales
    self.anchor_ratios = config.anchor_ratios
    self.anchor_strides = config.anchor_strides
    self.target_means = tuple(config.rcnn_target_means)
    self.target_stds = tuple(config.rcnn_target_stds)

    # Anchor generator: one generator per stride, using the stride itself as base size.
    self.anchor_base_sizes = list(self.anchor_strides)
    self.anchor_generators = [
        AnchorGenerator(base_size, self.anchor_scales, self.anchor_ratios)
        for base_size in self.anchor_base_sizes
    ]
    self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)

    level_shapes = config.feature_shapes
    assert len(level_shapes) == len(self.anchor_generators)
    self.anchor_list = self.get_anchors(level_shapes)

    # Backbone resnet
    self.backbone = ResNetFea(
        ResidualBlockUsing,
        config.resnet_block,
        config.resnet_in_channels,
        config.resnet_out_channels,
        False,
    )

    # Fpn
    self.fpn_ncek = FeatPyramidNeck(
        config.fpn_in_channels,
        config.fpn_out_channels,
        config.fpn_num_outs,
    )

    # Rpn and rpn loss; the RPN is fed all-ones labels for every gt slot.
    self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
    self.rpn_with_loss = RPN(
        config,
        self.train_batch_size,
        config.rpn_in_channels,
        config.rpn_feat_channels,
        config.num_anchors,
        config.rpn_cls_out_channels,
    )

    # Proposal: separate generators for training and for testing.
    self.proposal_generator = Proposal(
        config,
        self.train_batch_size,
        config.activate_num_classes,
        config.use_sigmoid_cls,
    )
    self.proposal_generator.set_train_local(config, True)
    self.proposal_generator_test = Proposal(
        config,
        config.test_batch_size,
        config.activate_num_classes,
        config.use_sigmoid_cls,
    )
    self.proposal_generator_test.set_train_local(config, False)

    # Assign and sampler stage two
    self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(
        config, self.train_batch_size, config.num_bboxes_stage2, True)
    self.decode = P.BoundingBoxDecode(
        max_shape=(config.img_height, config.img_width),
        means=self.target_means,
        stds=self.target_stds,
    )
    # Roi
    self.roi_init(config)

    # Rcnn: its input size is channels * out_size * out_size.
    rcnn_in_size = config.rcnn_in_channels * config.roi_layer.out_size * config.roi_layer.out_size
    self.rcnn = Rcnn(config, rcnn_in_size, self.train_batch_size, self.num_classes)

    # Op declare
    self.squeeze = P.Squeeze()
    self.cast = P.Cast()

    self.concat = P.Concat(axis=0)
    self.concat_1 = P.Concat(axis=1)
    self.concat_2 = P.Concat(axis=2)
    self.reshape = P.Reshape()
    self.select = P.Select()
    self.greater = P.Greater()
    self.transpose = P.Transpose()

    # Improve speed
    self.concat_start = min(self.num_classes - 2, 55)
    self.concat_end = self.num_classes - 1

    # Test mode
    self.test_mode_init(config)

    # Init tensor
    self.init_tensor(config)
    on_ascend = context.get_context("device_target") == "Ascend"
    self.device_type = "Ascend" if on_ascend else "Others"
|
||||
|
||||
def roi_init(self, config):
    """
    Create the RoI feature extractors used for training and for testing.

    Args:
        config (file): config file. The attributes read here are roi_layer,
            roi_align_out_channels, roi_align_featmap_strides and roi_align_finest_scale.

    Examples:
        self.roi_init(config)
    """
    def new_extractor(batch_size):
        # Both extractors share every setting except the batch size.
        return SingleRoIExtractor(config,
                                  config.roi_layer,
                                  config.roi_align_out_channels,
                                  config.roi_align_featmap_strides,
                                  batch_size,
                                  config.roi_align_finest_scale)

    # Training extractor works on the training batch size.
    self.roi_align = new_extractor(self.train_batch_size)
    self.roi_align.set_train_local(config, True)

    # Test extractor is built with a batch size of 1.
    self.roi_align_test = new_extractor(1)
    self.roi_align_test.set_train_local(config, False)
|
||||
|
||||
def test_mode_init(self, config):
    """
    Initialize the operators and constant tensors used at inference time.

    Args:
        config (file): config file. The attributes read here are test_batch_size,
            rpn_max_num, test_score_thr, test_iou_thr and test_max_per_img.

    Examples:
        self.test_mode_init(config)
    """
    self.test_batch_size = config.test_batch_size
    self.split = P.Split(axis=0, output_num=self.test_batch_size)
    self.split_shape = P.Split(axis=0, output_num=4)
    self.split_scores = P.Split(axis=1, output_num=self.num_classes)
    self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
    self.tile = P.Tile()
    self.gather = P.GatherNd()

    self.rpn_max_num = config.rpn_max_num

    self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.dtype))
    # np.bool_ is used because the bare `np.bool` alias was removed in NumPy 1.24.
    self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool_)
    self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool_)
    # Column pattern [True, False, True, False] over the 4 box coordinates.
    self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
                                            self.ones_mask, self.zeros_mask), axis=1))
    # Column pattern [True, True, True, True, False] over 5 columns.
    self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
                                               self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))

    # Per-proposal constant columns used by the post-processing.
    self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_score_thr)
    self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * 0)
    # Despite its name this tensor is filled with -1, not 0.
    self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.dtype) * -1)
    self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_iou_thr)
    self.test_max_per_img = config.test_max_per_img
    self.nms_test = P.NMSWithMask(config.test_iou_thr)
    self.softmax = P.Softmax(axis=1)
    self.logicand = P.LogicalAnd()
    self.oneslike = P.OnesLike()
    self.test_topk = P.TopK(sorted=True)
    self.test_num_proposal = self.test_batch_size * self.rpn_max_num
|
||||
|
||||
def init_tensor(self, config):
    """
    Precompute the image-index column that is prepended to the RoIs.

    Each image ``i`` of a batch contributes a column filled with ``i``: one row per sampled
    RoI (positives plus negatives) in training, one row per kept proposal in testing.

    Args:
        config: Configuration object; reads num_expected_pos_stage2, num_expected_neg_stage2
            and rpn_max_num.
    """
    rois_per_train_img = config.num_expected_pos_stage2 + config.num_expected_neg_stage2
    train_columns = [np.full((rois_per_train_img, 1), img_id, dtype=self.dtype)
                     for img_id in range(self.train_batch_size)]
    test_columns = [np.full((config.rpn_max_num, 1), img_id, dtype=self.dtype)
                    for img_id in range(self.test_batch_size)]

    self.roi_align_index_tensor = Tensor(np.concatenate(train_columns))
    self.roi_align_index_test_tensor = Tensor(np.concatenate(test_columns))
|
||||
|
||||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
    """
    Run the FasterRcnn network.

    In training mode the six loss terms are returned; otherwise the result of
    ``get_det_bboxes`` is returned.

    Args:
        img_data: input image data.
        img_metas: meta label of img.
        gt_bboxes (Tensor): ground-truth boxes.
        gt_labels (Tensor): ground-truth labels.
        gt_valids (Tensor): flags marking the valid entries of gt_bboxes.

    Returns:
        Tuple, tuple of output tensor. In training mode:
        (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss).
    """
    # Feature extraction: backbone followed by the FPN neck.
    x = self.backbone(img_data)
    x = self.fpn_ncek(x)

    # The RPN is always fed the constant all-ones stage-1 labels; its last output is unused.
    rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
                                                                                      img_metas,
                                                                                      self.anchor_list,
                                                                                      gt_bboxes,
                                                                                      self.gt_labels_stage1,
                                                                                      gt_valids)

    # Separate proposal generators are used for training and testing.
    if self.training:
        proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
    else:
        proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)

    gt_labels = self.cast(gt_labels, mstype.int32)
    gt_valids = self.cast(gt_valids, mstype.int32)
    bboxes_tuple = ()
    deltas_tuple = ()
    labels_tuple = ()
    mask_tuple = ()
    if self.training:
        # Stage-two assignment and sampling is done image by image.
        for i in range(self.train_batch_size):
            gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])

            gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
            gt_labels_i = self.cast(gt_labels_i, mstype.uint8)

            gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
            gt_valids_i = self.cast(gt_valids_i, mstype.bool_)

            # Only the first 4 columns of each proposal are passed on as box coordinates.
            bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
                                                                               gt_labels_i,
                                                                               proposal_mask[i],
                                                                               proposal[i][::, 0:4:1],
                                                                               gt_valids_i)
            bboxes_tuple += (bboxes,)
            deltas_tuple += (deltas,)
            labels_tuple += (labels,)
            mask_tuple += (mask,)

        # Targets and labels are constants for the RCNN loss: no gradient flows through them.
        bbox_targets = self.concat(deltas_tuple)
        rcnn_labels = self.concat(labels_tuple)
        bbox_targets = F.stop_gradient(bbox_targets)
        rcnn_labels = F.stop_gradient(rcnn_labels)
        rcnn_labels = self.cast(rcnn_labels, mstype.int32)
    else:
        # NOTE(review): bbox_targets / rcnn_labels are set to proposal_mask here, apparently
        # as placeholders for the Rcnn call below -- confirm Rcnn ignores them in eval mode.
        mask_tuple += proposal_mask
        bbox_targets = proposal_mask
        rcnn_labels = proposal_mask
        for p_i in proposal:
            bboxes_tuple += (p_i[::, 0:4:1],)

    # Build the RoIs: image-index column followed by the box coordinates.
    if self.training:
        if self.train_batch_size > 1:
            bboxes_all = self.concat(bboxes_tuple)
        else:
            bboxes_all = bboxes_tuple[0]
        rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
    else:
        if self.test_batch_size > 1:
            bboxes_all = self.concat(bboxes_tuple)
        else:
            bboxes_all = bboxes_tuple[0]
        # On Ascend the test boxes are cast to float16 before being concatenated.
        if self.device_type == "Ascend":
            bboxes_all = self.cast(bboxes_all, mstype.float16)
        rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))

    rois = self.cast(rois, mstype.float32)
    rois = F.stop_gradient(rois)

    # Only the first four feature maps of x are passed to the RoI extractor.
    if self.training:
        roi_feats = self.roi_align(rois,
                                   self.cast(x[0], mstype.float32),
                                   self.cast(x[1], mstype.float32),
                                   self.cast(x[2], mstype.float32),
                                   self.cast(x[3], mstype.float32))
    else:
        roi_feats = self.roi_align_test(rois,
                                        self.cast(x[0], mstype.float32),
                                        self.cast(x[1], mstype.float32),
                                        self.cast(x[2], mstype.float32),
                                        self.cast(x[3], mstype.float32))

    roi_feats = self.cast(roi_feats, self.ms_type)
    rcnn_masks = self.concat(mask_tuple)
    rcnn_masks = F.stop_gradient(rcnn_masks)
    rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
    rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
                                                           bbox_targets,
                                                           rcnn_labels,
                                                           rcnn_mask_squeeze)

    output = ()
    if self.training:
        output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
    else:
        # In eval mode the second and third Rcnn outputs are handed to get_det_bboxes as its
        # cls_logits / reg_logits arguments (presumably raw predictions then -- TODO confirm).
        output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)

    return output
|
||||
|
||||
def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
    """
    Get the actual detection box.

    Decodes the per-class regression output against the rois, rescales the boxes of every
    image with the scale factors taken from ``img_metas`` and runs ``multiclass_nms``.

    Args:
        cls_logits (Tensor): classification output; softmax is applied over axis 1.
        reg_logits (Tensor): regression output with 4 consecutive columns per class.
        mask_logits (Tensor): validity mask of the rois, cast to int32 before splitting.
        rois (Tensor): boxes the regression output is decoded against.
        img_metas (Tensor): per-image meta data; after splitting into 4 parts, parts 2 and 3
            are used as the height and width scale factors.

    Returns:
        The output of ``self.multiclass_nms``.
    """
    scores = self.softmax(cls_logits)

    # Decode every class and split the result per image right away, so the split is done
    # once per class instead of once per (image, class) pair in the loop below.
    boxes_split = ()
    for i in range(self.num_classes):
        k = i * 4
        reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
        out_boxes_i = self.decode(rois, reg_logits_i)
        boxes_split += (self.split(out_boxes_i),)

    img_metas_all = self.split(img_metas)
    scores_all = self.split(scores)
    mask_all = self.split(self.cast(mask_logits, mstype.int32))

    boxes_all_with_batchsize = ()
    for i in range(self.test_batch_size):
        scale = self.split_shape(self.squeeze(img_metas_all[i]))
        scale_h = scale[2]
        scale_w = scale[3]
        boxes_tuple = ()
        for j in range(self.num_classes):
            boxes_j = boxes_split[j]
            out_boxes_h = boxes_j[i] / scale_h
            out_boxes_w = boxes_j[i] / scale_w
            # bbox_mask picks the width-scaled value for columns 0 and 2 and the
            # height-scaled value for columns 1 and 3.
            boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
        boxes_all_with_batchsize += (boxes_tuple,)

    output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)

    return output
|
||||
|
||||
def multiclass_nms(self, boxes_all, scores_all, mask_all):
    """Run per-class score filtering and NMS for every image of the batch.

    Args:
        boxes_all: tuple indexed by image, each entry a tuple of per-class boxes.
        scores_all: tuple indexed by image with the class scores.
        mask_all: tuple indexed by image with the validity mask of the proposals.

    Returns:
        Tuple ``(all_bboxes, all_labels, all_masks)`` concatenated over the
        images; per image the tensors are reshaped to
        ``(1, (num_classes - 1) * rpn_max_num, 5 | 1 | 1)``.
    """
    all_bboxes = ()
    all_labels = ()
    all_masks = ()

    for i in range(self.test_batch_size):
        bboxes = boxes_all[i]
        scores = scores_all[i]
        masks = self.cast(mask_all[i], mstype.bool_)

        res_boxes_tuple = ()
        res_labels_tuple = ()
        res_masks_tuple = ()

        # Class index 0 is skipped: k runs over 1 .. num_classes - 1.
        for j in range(self.num_classes - 1):
            k = j + 1
            _cls_scores = scores[::, k:k + 1:1]
            _bboxes = self.squeeze(bboxes[k])
            _mask_o = self.reshape(masks, (self.rpn_max_num, 1))

            # Keep entries that are valid and score above the threshold.
            cls_mask = self.greater(_cls_scores, self.test_score_thresh)
            _mask = self.logicand(_mask_o, cls_mask)

            # Repeat the (N, 1) mask over the 4 box coordinates.
            _reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)

            # Rejected entries get the placeholder box / score values.
            _bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
            _cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
            __cls_scores = self.squeeze(_cls_scores)
            # Order boxes and masks by score using the TopK indices.
            scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
            topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
            scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
            _bboxes_sorted = self.gather(_bboxes, topk_inds)
            _mask_sorted = self.gather(_mask, topk_inds)

            # The score is tiled to 4 columns, concatenated to the boxes and
            # the result is cut back to 5 columns (4 coordinates + 1 score).
            scores_sorted = self.tile(scores_sorted, (1, 4))
            cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
            cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))

            cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
            _index = self.reshape(_index, (self.rpn_max_num, 1))
            _mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))

            # Re-order the score mask by the NMS output indices.
            _mask_n = self.gather(_mask_sorted, _index)

            # A detection survives only if it passed both the score filter and NMS.
            _mask_n = self.logicand(_mask_n, _mask_nms)
            # Labels are the loop index j, i.e. class k - 1.
            cls_labels = self.oneslike(_index) * j
            res_boxes_tuple += (cls_dets,)
            res_labels_tuple += (cls_labels,)
            res_masks_tuple += (_mask_n,)

        # The per-class results are concatenated in two pieces split at
        # self.concat_start, then joined.
        res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
        res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
        res_masks_start = self.concat(res_masks_tuple[:self.concat_start])

        res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
        res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
        res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])

        res_boxes = self.concat((res_boxes_start, res_boxes_end))
        res_labels = self.concat((res_labels_start, res_labels_end))
        res_masks = self.concat((res_masks_start, res_masks_end))

        # Add a leading image axis so the per-image results can be stacked.
        reshape_size = (self.num_classes - 1) * self.rpn_max_num
        res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
        res_labels = self.reshape(res_labels, (1, reshape_size, 1))
        res_masks = self.reshape(res_masks, (1, reshape_size, 1))

        all_bboxes += (res_boxes,)
        all_labels += (res_labels,)
        all_masks += (res_masks,)

    all_bboxes = self.concat(all_bboxes)
    all_labels = self.concat(all_labels)
    all_masks = self.concat(all_masks)
    return all_bboxes, all_labels, all_masks
|
||||
|
||||
def get_anchors(self, featmap_sizes):
    """Build the anchor tensors of every feature-map level.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.

    Returns:
        tuple: one Tensor of anchors per level, cast to ``self.dtype``.
    """
    # Feature map sizes are the same for all images, so the anchors are
    # generated only once.
    anchors_per_level = ()
    for level, featmap_size in enumerate(featmap_sizes):
        generator = self.anchor_generators[level]
        level_anchors = generator.grid_anchors(featmap_size, self.anchor_strides[level])
        anchors_per_level = anchors_per_level + (Tensor(level_anchors.astype(self.dtype)),)

    return anchors_per_level
|
||||
|
||||
class FasterRcnn_Infer(nn.Cell):
    """Inference wrapper that runs Faster_Rcnn_Resnet with training disabled."""

    def __init__(self, config):
        super().__init__()
        self.network = Faster_Rcnn_Resnet(config)
        # The wrapped network is only used for evaluation.
        self.network.set_train(False)

    def construct(self, img_data, img_metas):
        """Run the wrapped network; the three ground-truth inputs are None."""
        return self.network(img_data, img_metas, None, None, None)
|
|
@ -0,0 +1,488 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn based on ResNet50v1.0."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from .resnet import ResNetFea
|
||||
from .resnet50v1 import ResidualBlockUsing_V1
|
||||
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
|
||||
from .fpn_neck import FeatPyramidNeck
|
||||
from .proposal_generator import Proposal
|
||||
from .rcnn import Rcnn
|
||||
from .rpn import RPN
|
||||
from .roi_align import SingleRoIExtractor
|
||||
from .anchor_generator import AnchorGenerator
|
||||
|
||||
|
||||
class Faster_Rcnn_Resnet(nn.Cell):
    """
    FasterRcnn Network.

    Note:
        backbone = resnet

    Args:
        config: configuration object; its attributes (batch sizes, anchor
            settings, channel sizes, test thresholds, ...) are read in
            ``__init__`` and the ``*_init`` helpers.

    Returns:
        Tuple, tuple of output tensor.
        In training mode:
        rpn_loss: Scalar, Total loss of RPN subnet.
        rcnn_loss: Scalar, Total loss of RCNN subnet.
        rpn_cls_loss: Scalar, Classification loss of RPN subnet.
        rpn_reg_loss: Scalar, Regression loss of RPN subnet.
        rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
        rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.
        In evaluation mode: the (bboxes, labels, masks) tuple produced by
        ``multiclass_nms``.

    Examples:
        net = Faster_Rcnn_Resnet(config)
    """
    def __init__(self, config):
        super(Faster_Rcnn_Resnet, self).__init__()
        self.dtype = np.float32
        self.ms_type = mstype.float32
        self.train_batch_size = config.batch_size
        self.num_classes = config.num_classes
        self.anchor_scales = config.anchor_scales
        self.anchor_ratios = config.anchor_ratios
        self.anchor_strides = config.anchor_strides
        self.target_means = tuple(config.rcnn_target_means)
        self.target_stds = tuple(config.rcnn_target_stds)

        # Anchor generator: one generator per stride, the stride doubling as
        # the anchor base size.
        self.anchor_base_sizes = list(self.anchor_strides)

        self.anchor_generators = []
        for anchor_base in self.anchor_base_sizes:
            self.anchor_generators.append(
                AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))

        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)

        featmap_sizes = config.feature_shapes
        assert len(featmap_sizes) == len(self.anchor_generators)

        self.anchor_list = self.get_anchors(featmap_sizes)

        # Backbone resnet
        self.backbone = ResNetFea(ResidualBlockUsing_V1,
                                  config.resnet_block,
                                  config.resnet_in_channels,
                                  config.resnet_out_channels,
                                  False)

        # Fpn
        # NOTE(review): "fpn_ncek" looks like a typo of "fpn_neck", but the
        # attribute name presumably becomes part of the parameter names stored
        # in checkpoints - do not rename without checking.
        self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
                                        config.fpn_out_channels,
                                        config.fpn_num_outs)

        # Rpn and rpn loss
        self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
        self.rpn_with_loss = RPN(config,
                                 self.train_batch_size,
                                 config.rpn_in_channels,
                                 config.rpn_feat_channels,
                                 config.num_anchors,
                                 config.rpn_cls_out_channels)

        # Proposal: separate generators for training and test settings.
        self.proposal_generator = Proposal(config,
                                           self.train_batch_size,
                                           config.activate_num_classes,
                                           config.use_sigmoid_cls)
        self.proposal_generator.set_train_local(config, True)
        self.proposal_generator_test = Proposal(config,
                                                config.test_batch_size,
                                                config.activate_num_classes,
                                                config.use_sigmoid_cls)
        self.proposal_generator_test.set_train_local(config, False)

        # Assign and sampler stage two
        self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
                                                                      config.num_bboxes_stage2, True)
        self.decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width),
                                          means=self.target_means,
                                          stds=self.target_stds)
        # Roi
        self.roi_init(config)

        # Rcnn
        self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer.out_size * config.roi_layer.out_size,
                         self.train_batch_size, self.num_classes)

        # Op declare
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

        self.concat = P.Concat(axis=0)
        self.concat_1 = P.Concat(axis=1)
        self.concat_2 = P.Concat(axis=2)
        self.reshape = P.Reshape()
        self.select = P.Select()
        self.greater = P.Greater()
        self.transpose = P.Transpose()

        # Improve speed: multiclass_nms concatenates the per-class results in
        # two pieces split at concat_start.
        # NOTE(review): with num_classes == 2 concat_start is 0, so the first
        # piece is an empty tuple - confirm Concat accepts that.
        self.concat_start = min(self.num_classes - 2, 55)
        self.concat_end = (self.num_classes - 1)

        # Test mode
        self.test_mode_init(config)

        # Init tensor
        self.init_tensor(config)
        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"

    def roi_init(self, config):
        """
        Initialize the training and test RoI extractors from the config.

        Args:
            config: config object. The attributes read here are
                roi_layer, roi_align_out_channels, roi_align_featmap_strides
                and roi_align_finest_scale.

        Examples:
            self.roi_init(config)
        """
        self.roi_align = SingleRoIExtractor(config,
                                            config.roi_layer,
                                            config.roi_align_out_channels,
                                            config.roi_align_featmap_strides,
                                            self.train_batch_size,
                                            config.roi_align_finest_scale)
        self.roi_align.set_train_local(config, True)
        # NOTE(review): the test extractor is built with batch size 1 rather
        # than config.test_batch_size - confirm this is intended.
        self.roi_align_test = SingleRoIExtractor(config,
                                                 config.roi_layer,
                                                 config.roi_align_out_channels,
                                                 config.roi_align_featmap_strides,
                                                 1,
                                                 config.roi_align_finest_scale)
        self.roi_align_test.set_train_local(config, False)

    def test_mode_init(self, config):
        """
        Initialize the operators and constant tensors used in test mode.

        Args:
            config: config object. The attributes read here are
                test_batch_size, rpn_max_num, test_score_thr, test_iou_thr
                and test_max_per_img.

        Examples:
            self.test_mode_init(config)
        """
        self.test_batch_size = config.test_batch_size
        self.split = P.Split(axis=0, output_num=self.test_batch_size)
        self.split_shape = P.Split(axis=0, output_num=4)
        self.split_scores = P.Split(axis=1, output_num=self.num_classes)
        self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
        self.tile = P.Tile()
        self.gather = P.GatherNd()

        self.rpn_max_num = config.rpn_max_num

        self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.dtype))
        # np.bool_ instead of the np.bool alias, which was deprecated in
        # NumPy 1.20 and removed in 1.24.
        self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool_)
        self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool_)
        # Column pattern (True, False, True, False): get_det_bboxes uses it to
        # pick the width-scaled value for columns 0 and 2 and the
        # height-scaled value for columns 1 and 3.
        self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
                                                self.ones_mask, self.zeros_mask), axis=1))
        self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
                                                   self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))

        self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_score_thr)
        self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * 0)
        # Despite the name this placeholder is filled with -1, not 0.
        self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.dtype) * -1)
        self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_iou_thr)
        self.test_max_per_img = config.test_max_per_img
        self.nms_test = P.NMSWithMask(config.test_iou_thr)
        self.softmax = P.Softmax(axis=1)
        self.logicand = P.LogicalAnd()
        self.oneslike = P.OnesLike()
        self.test_topk = P.TopK(sorted=True)
        self.test_num_proposal = self.test_batch_size * self.rpn_max_num

    def init_tensor(self, config):
        """Build the image-index columns that are prepended to the rois.

        For image ``i`` a column filled with ``i`` is created; the columns of
        all images are concatenated. The training column has
        ``num_expected_pos_stage2 + num_expected_neg_stage2`` rows per image,
        the test column ``rpn_max_num`` rows per image.
        """
        train_rows = config.num_expected_pos_stage2 + config.num_expected_neg_stage2
        roi_align_index = [np.full((train_rows, 1), i, dtype=self.dtype)
                           for i in range(self.train_batch_size)]

        roi_align_index_test = [np.full((config.rpn_max_num, 1), i, dtype=self.dtype)
                                for i in range(self.test_batch_size)]

        self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
        self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))

    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        """
        construct the FasterRcnn Network.

        Args:
            img_data: input image data.
            img_metas: meta label of img.
            gt_bboxes (Tensor): get the value of bboxes.
            gt_labels (Tensor): get the value of labels.
            gt_valids (Tensor): get the valid part of bboxes.

        Returns:
            Tuple,tuple of output tensor: the six losses in training mode,
            the output of get_det_bboxes otherwise.
        """
        x = self.backbone(img_data)
        x = self.fpn_ncek(x)

        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
                                                                                           img_metas,
                                                                                           self.anchor_list,
                                                                                           gt_bboxes,
                                                                                           self.gt_labels_stage1,
                                                                                           gt_valids)

        if self.training:
            proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
        else:
            proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)

        gt_labels = self.cast(gt_labels, mstype.int32)
        gt_valids = self.cast(gt_valids, mstype.int32)
        bboxes_tuple = ()
        deltas_tuple = ()
        labels_tuple = ()
        mask_tuple = ()
        if self.training:
            # Assign and sample RCNN targets image by image.
            for i in range(self.train_batch_size):
                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])

                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_labels_i = self.cast(gt_labels_i, mstype.uint8)

                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
                gt_valids_i = self.cast(gt_valids_i, mstype.bool_)

                bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
                                                                                   gt_labels_i,
                                                                                   proposal_mask[i],
                                                                                   proposal[i][::, 0:4:1],
                                                                                   gt_valids_i)
                bboxes_tuple += (bboxes,)
                deltas_tuple += (deltas,)
                labels_tuple += (labels,)
                mask_tuple += (mask,)

            bbox_targets = self.concat(deltas_tuple)
            rcnn_labels = self.concat(labels_tuple)
            bbox_targets = F.stop_gradient(bbox_targets)
            rcnn_labels = F.stop_gradient(rcnn_labels)
            rcnn_labels = self.cast(rcnn_labels, mstype.int32)
        else:
            # No targets exist at test time; the proposal masks are passed in
            # the target/label slots of self.rcnn instead.
            mask_tuple += proposal_mask
            bbox_targets = proposal_mask
            rcnn_labels = proposal_mask
            for p_i in proposal:
                bboxes_tuple += (p_i[::, 0:4:1],)

        # Prepend the image-index column to the boxes to form the rois.
        if self.training:
            if self.train_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
        else:
            if self.test_batch_size > 1:
                bboxes_all = self.concat(bboxes_tuple)
            else:
                bboxes_all = bboxes_tuple[0]
            if self.device_type == "Ascend":
                bboxes_all = self.cast(bboxes_all, mstype.float16)
            rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))

        rois = self.cast(rois, mstype.float32)
        rois = F.stop_gradient(rois)

        # Only the first four pyramid levels are fed to the RoI extractor.
        if self.training:
            roi_feats = self.roi_align(rois,
                                       self.cast(x[0], mstype.float32),
                                       self.cast(x[1], mstype.float32),
                                       self.cast(x[2], mstype.float32),
                                       self.cast(x[3], mstype.float32))
        else:
            roi_feats = self.roi_align_test(rois,
                                            self.cast(x[0], mstype.float32),
                                            self.cast(x[1], mstype.float32),
                                            self.cast(x[2], mstype.float32),
                                            self.cast(x[3], mstype.float32))

        roi_feats = self.cast(roi_feats, self.ms_type)
        rcnn_masks = self.concat(mask_tuple)
        rcnn_masks = F.stop_gradient(rcnn_masks)
        rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
        rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
                                                               bbox_targets,
                                                               rcnn_labels,
                                                               rcnn_mask_squeeze)

        output = ()
        if self.training:
            output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
        else:
            # In test mode the 2nd and 3rd outputs of self.rcnn are consumed
            # as class / regression logits (presumably Rcnn returns logits
            # when not training - confirm in rcnn.py).
            output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)

        return output

    def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
        """Decode per-class boxes, rescale them per image and run multi-class NMS.

        Args:
            cls_logits: classification output; softmax over axis 1 gives the scores.
            reg_logits: regression output; columns ``4 * i`` to ``4 * i + 3``
                are used as the box deltas of class ``i``.
            mask_logits: validity mask of the rois; cast to int32 before the split.
            rois: boxes the deltas are decoded against.
            img_metas: per-image meta information; elements 2 and 3 of each
                image's meta vector are used as height and width scale factors.

        Returns:
            The (bboxes, labels, masks) tuple returned by ``multiclass_nms``.
        """
        scores = self.softmax(cls_logits)

        # Decode one set of boxes per class from its own 4 regression columns.
        boxes_all = ()
        for i in range(self.num_classes):
            k = i * 4
            reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
            out_boxes_i = self.decode(rois, reg_logits_i)
            boxes_all += (out_boxes_i,)

        # self.split cuts axis 0 into test_batch_size equal parts.
        img_metas_all = self.split(img_metas)
        scores_all = self.split(scores)
        mask_all = self.split(self.cast(mask_logits, mstype.int32))

        boxes_all_with_batchsize = ()
        for i in range(self.test_batch_size):
            scale = self.split_shape(self.squeeze(img_metas_all[i]))
            scale_h = scale[2]
            scale_w = scale[3]
            boxes_tuple = ()
            for j in range(self.num_classes):
                boxes_tmp = self.split(boxes_all[j])
                # Both rescaled variants are computed; bbox_mask then takes
                # the width-scaled value for columns 0 and 2 and the
                # height-scaled value for columns 1 and 3.
                out_boxes_h = boxes_tmp[i] / scale_h
                out_boxes_w = boxes_tmp[i] / scale_w
                boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
            boxes_all_with_batchsize += (boxes_tuple,)

        output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)

        return output

    def multiclass_nms(self, boxes_all, scores_all, mask_all):
        """Run per-class score filtering and NMS for every image of the batch.

        Args:
            boxes_all: tuple indexed by image, each entry a tuple of per-class boxes.
            scores_all: tuple indexed by image with the class scores.
            mask_all: tuple indexed by image with the validity mask of the proposals.

        Returns:
            Tuple ``(all_bboxes, all_labels, all_masks)`` concatenated over the
            images; per image the tensors are reshaped to
            ``(1, (num_classes - 1) * rpn_max_num, 5 | 1 | 1)``.
        """
        all_bboxes = ()
        all_labels = ()
        all_masks = ()

        for i in range(self.test_batch_size):
            bboxes = boxes_all[i]
            scores = scores_all[i]
            masks = self.cast(mask_all[i], mstype.bool_)

            res_boxes_tuple = ()
            res_labels_tuple = ()
            res_masks_tuple = ()

            # Class index 0 is skipped: k runs over 1 .. num_classes - 1.
            for j in range(self.num_classes - 1):
                k = j + 1
                _cls_scores = scores[::, k:k + 1:1]
                _bboxes = self.squeeze(bboxes[k])
                _mask_o = self.reshape(masks, (self.rpn_max_num, 1))

                # Keep entries that are valid and score above the threshold.
                cls_mask = self.greater(_cls_scores, self.test_score_thresh)
                _mask = self.logicand(_mask_o, cls_mask)

                # Repeat the (N, 1) mask over the 4 box coordinates.
                _reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)

                # Rejected entries get box value -1 and score 0.
                _bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
                _cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
                __cls_scores = self.squeeze(_cls_scores)
                # Order boxes and masks by descending score.
                scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
                topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
                scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
                _bboxes_sorted = self.gather(_bboxes, topk_inds)
                _mask_sorted = self.gather(_mask, topk_inds)

                # The score is tiled to 4 columns, concatenated to the boxes
                # and the result cut back to 5 columns (4 coordinates + score).
                scores_sorted = self.tile(scores_sorted, (1, 4))
                cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
                cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))

                cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
                _index = self.reshape(_index, (self.rpn_max_num, 1))
                _mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))

                # Re-order the score mask by the NMS output indices.
                _mask_n = self.gather(_mask_sorted, _index)

                # A detection survives only if it passed both the score
                # filter and NMS.
                _mask_n = self.logicand(_mask_n, _mask_nms)
                # Labels are the loop index j, i.e. class k - 1.
                cls_labels = self.oneslike(_index) * j
                res_boxes_tuple += (cls_dets,)
                res_labels_tuple += (cls_labels,)
                res_masks_tuple += (_mask_n,)

            # The per-class results are concatenated in two pieces split at
            # self.concat_start, then joined.
            res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
            res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
            res_masks_start = self.concat(res_masks_tuple[:self.concat_start])

            res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
            res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
            res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])

            res_boxes = self.concat((res_boxes_start, res_boxes_end))
            res_labels = self.concat((res_labels_start, res_labels_end))
            res_masks = self.concat((res_masks_start, res_masks_end))

            # Add a leading image axis so the per-image results can be stacked.
            reshape_size = (self.num_classes - 1) * self.rpn_max_num
            res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
            res_labels = self.reshape(res_labels, (1, reshape_size, 1))
            res_masks = self.reshape(res_masks, (1, reshape_size, 1))

            all_bboxes += (res_boxes,)
            all_labels += (res_labels,)
            all_masks += (res_masks,)

        all_bboxes = self.concat(all_bboxes)
        all_labels = self.concat(all_labels)
        all_masks = self.concat(all_masks)
        return all_bboxes, all_labels, all_masks

    def get_anchors(self, featmap_sizes):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.

        Returns:
            tuple: one Tensor of anchors per feature level, cast to ``self.dtype``.
        """
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = ()
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors += (Tensor(anchors.astype(self.dtype)),)

        return multi_level_anchors
|
||||
|
||||
class FasterRcnn_Infer(nn.Cell):
    """Inference wrapper that runs Faster_Rcnn_Resnet with training disabled."""

    def __init__(self, config):
        super().__init__()
        self.network = Faster_Rcnn_Resnet(config)
        # The wrapped network is only used for evaluation.
        self.network.set_train(False)

    def construct(self, img_data, img_metas):
        """Run the wrapped network; the three ground-truth inputs are None."""
        return self.network(img_data, img_metas, None, None, None)
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn feature pyramid network."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
|
||||
def bias_init_zeros(shape):
    """Return a float32 Tensor of the given shape filled with zeros."""
    zero_bias = np.zeros(shape, dtype=np.float32)
    return Tensor(zero_bias)
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
    """Build a biased nn.Conv2d with XavierUniform weights and a zero bias."""
    weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
    weight_tensor = initializer("XavierUniform", shape=weight_shape, dtype=mstype.float32).to_tensor()
    bias_tensor = bias_init_zeros((out_channels,))
    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding,
                       pad_mode=pad_mode, weight_init=weight_tensor,
                       has_bias=True, bias_init=bias_tensor)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
|
||||
|
||||
class FeatPyramidNeck(nn.Cell):
    """
    Feature pyramid network cell, usually uses as network neck.

    Applies the convolution on multiple, input feature maps
    and output feature map with same channel size. if required num of
    output larger then num of inputs, add extra maxpooling for further
    downsampling;

    Note:
        ``construct`` indexes exactly four lateral outputs and upsamples to
        the fixed sizes (48, 80), (96, 160) and (192, 320), so the cell
        expects four input levels whose spatial sizes match those values.

    Args:
        in_channels (tuple) - Channel size of input feature maps.
        out_channels (int) - Channel size output.
        num_outs (int) - Num of output features.

    Returns:
        Tuple, with tensors of same channel size.

    Examples:
        neck = FeatPyramidNeck([256, 512, 1024, 2048], 256, 5)
        input_data = (normal(0,0.1,(1,c,768//(4*2**i), 1280//(4*2**i)),
                      dtype=np.float32) \
                      for i, c in enumerate(config.fpn_in_channels))
        x = neck(input_data)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs):
        super(FeatPyramidNeck, self).__init__()
        self.num_outs = num_outs
        self.in_channels = in_channels
        self.fpn_layer = len(self.in_channels)

        assert not self.num_outs < len(in_channels)

        self.lateral_convs_list_ = []
        self.fpn_convs_ = []

        # One 1x1 lateral conv and one 3x3 output conv per input level; the
        # creation order is kept because the weights are randomly initialized.
        for channel in in_channels:
            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
            self.lateral_convs_list_.append(l_conv)
            self.fpn_convs_.append(fpn_conv)
        self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
        self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
        # Fixed upsampling targets for the top-down path (coarse to fine).
        self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
        self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
        self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
        # kernel_size=1 with stride 2 simply subsamples every second position.
        self.maxpool = P.MaxPool(kernel_size=1, strides=2, pad_mode="same")

    def construct(self, inputs):
        """Build the pyramid: lateral convs, top-down merge, output convs, extra levels."""
        x = ()
        for i in range(self.fpn_layer):
            x += (self.lateral_convs_list[i](inputs[i]),)

        # Top-down path: start at the coarsest level and add the upsampled
        # result to the next finer lateral output. y is ordered coarse to fine.
        y = (x[3],)
        y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
        y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
        y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)

        # Reverse so that z is ordered fine to coarse like the inputs.
        z = ()
        for i in range(self.fpn_layer - 1, -1, -1):
            z = z + (y[i],)

        outs = ()
        for i in range(self.fpn_layer):
            outs = outs + (self.fpn_convs_list[i](z[i]),)

        # Extra levels: each one subsamples the previous (last) output, so
        # several extra levels keep getting coarser instead of all repeating
        # the pooled version of the same map.
        for i in range(self.num_outs - self.fpn_layer):
            outs = outs + (self.maxpool(outs[self.fpn_layer + i - 1]),)
        return outs
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn proposal generator."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Proposal(nn.Cell):
|
||||
"""
|
||||
Proposal subnet.
|
||||
|
||||
Args:
|
||||
config (dict): Config.
|
||||
batch_size (int): Batchsize.
|
||||
num_classes (int) - Class number.
|
||||
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
|
||||
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
|
||||
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor,(proposal, mask).
|
||||
|
||||
Examples:
|
||||
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
|
||||
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
batch_size,
|
||||
num_classes,
|
||||
use_sigmoid_cls,
|
||||
target_means=(.0, .0, .0, .0),
|
||||
target_stds=(1.0, 1.0, 1.0, 1.0)
|
||||
):
|
||||
super(Proposal, self).__init__()
|
||||
cfg = config
|
||||
self.batch_size = batch_size
|
||||
self.num_classes = num_classes
|
||||
self.target_means = target_means
|
||||
self.target_stds = target_stds
|
||||
self.use_sigmoid_cls = use_sigmoid_cls
|
||||
|
||||
if self.use_sigmoid_cls:
|
||||
self.cls_out_channels = num_classes - 1
|
||||
self.activation = P.Sigmoid()
|
||||
self.reshape_shape = (-1, 1)
|
||||
else:
|
||||
self.cls_out_channels = num_classes
|
||||
self.activation = P.Softmax(axis=1)
|
||||
self.reshape_shape = (-1, 2)
|
||||
|
||||
if self.cls_out_channels <= 0:
|
||||
raise ValueError('num_classes={} is too small'.format(num_classes))
|
||||
|
||||
self.num_pre = cfg.rpn_proposal_nms_pre
|
||||
self.min_box_size = cfg.rpn_proposal_min_bbox_size
|
||||
self.nms_thr = cfg.rpn_proposal_nms_thr
|
||||
self.nms_post = cfg.rpn_proposal_nms_post
|
||||
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
|
||||
self.max_num = cfg.rpn_proposal_max_num
|
||||
self.num_levels = cfg.fpn_num_outs
|
||||
|
||||
# Op Define
|
||||
self.squeeze = P.Squeeze()
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
self.feature_shapes = cfg.feature_shapes
|
||||
|
||||
self.transpose_shape = (1, 2, 0)
|
||||
|
||||
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \
|
||||
means=self.target_means, \
|
||||
stds=self.target_stds)
|
||||
|
||||
self.nms = P.NMSWithMask(self.nms_thr)
|
||||
self.concat_axis0 = P.Concat(axis=0)
|
||||
self.concat_axis1 = P.Concat(axis=1)
|
||||
self.split = P.Split(axis=1, output_num=5)
|
||||
self.min = P.Minimum()
|
||||
self.gatherND = P.GatherNd()
|
||||
self.slice = P.Slice()
|
||||
self.select = P.Select()
|
||||
self.greater = P.Greater()
|
||||
self.transpose = P.Transpose()
|
||||
self.tile = P.Tile()
|
||||
self.set_train_local(config, training=True)
|
||||
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
|
||||
self.multi_10 = Tensor(10.0, self.ms_type)
|
||||
|
||||
def set_train_local(self, config, training=True):
    """Select training or inference proposal settings and rebuild the top-k tables.

    Args:
        config: Configuration object exposing the ``rpn_proposal_*`` (training)
            and ``rpn_*`` (inference) proposal settings as attributes.
        training (bool): True selects the training settings, False the
            inference settings. Default: True.
    """
    self.training_local = training
    cfg = config

    if self.training_local:
        # Restore the training values explicitly: otherwise a call with
        # training=False followed by training=True would silently keep the
        # inference settings.
        self.num_pre = cfg.rpn_proposal_nms_pre
        self.min_box_size = cfg.rpn_proposal_min_bbox_size
        self.nms_thr = cfg.rpn_proposal_nms_thr
        self.nms_post = cfg.rpn_proposal_nms_post
        self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
        self.max_num = cfg.rpn_proposal_max_num
    else:
        self.num_pre = cfg.rpn_nms_pre
        self.min_box_size = cfg.rpn_min_bbox_min_size
        self.nms_thr = cfg.rpn_nms_thr
        self.nms_post = cfg.rpn_nms_post
        self.nms_across_levels = cfg.rpn_nms_across_levels
        self.max_num = cfg.rpn_max_num
    # NOTE(review): only the stored threshold is updated here; the NMS operator
    # itself is not rebuilt by this method -- confirm that this is intended.

    # NOTE(review): 3 is presumably the number of anchors per feature-map
    # location (cfg.num_anchors) -- confirm before changing the anchor setup.
    anchors_per_location = 3

    # Per-level top-k size: never ask for more candidates than the level has.
    topk_sizes = ()
    topk_shapes = ()
    total_max_topk_input = 0
    for shp in self.feature_shapes:
        k_num = min(self.num_pre, shp[0] * shp[1] * anchors_per_location)
        total_max_topk_input += k_num
        topk_sizes += (k_num,)
        topk_shapes += ((k_num, 1),)
    self.topK_stage1 = topk_sizes
    self.topK_shape = topk_shapes

    self.topKv2 = P.TopK(sorted=True)
    self.topK_shape_stage2 = (self.max_num, 1)
    # Very low score substituted for entries whose mask is False, so that
    # they sort behind the unmasked scores in the second top-k.
    self.min_float_num = -65500.0
    self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float32))
|
||||
|
||||
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
    """Build proposals and validity masks for every image of the batch.

    Args:
        rpn_cls_score_total: Per-level classification score maps, indexed by
            level and then sliced along the leading (image) axis.
        rpn_bbox_pred_total: Per-level box regression maps, indexed the same way.
        anchor_list: Per-level anchors, forwarded to ``get_bboxes_single``.

    Returns:
        Tuple (proposals, masks); each is a tuple with one entry per image.
    """
    batch_proposals = ()
    batch_masks = ()
    for img_id in range(self.batch_size):
        level_scores = ()
        level_deltas = ()
        for level in range(self.num_levels):
            # Take the slice of one image along axis 0, then squeeze it.
            score_map = self.squeeze(rpn_cls_score_total[level][img_id:img_id+1:1, ::, ::, ::])
            delta_map = self.squeeze(rpn_bbox_pred_total[level][img_id:img_id+1:1, ::, ::, ::])
            level_scores += (score_map,)
            level_deltas += (delta_map,)

        img_proposals, img_masks = self.get_bboxes_single(level_scores, level_deltas, anchor_list)
        batch_proposals += (img_proposals,)
        batch_masks += (img_masks,)
    return batch_proposals, batch_masks
|
||||
|
||||
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
    """Decode, rank and NMS-filter the proposals of one image.

    Args:
        cls_scores: Per-level classification score maps of one image.
        bbox_preds: Per-level box regression maps of one image.
        mlvl_anchors: Per-level anchors.

    Returns:
        Tuple (proposals, masks) gathered with the indices of the final
        top-k over all levels (``self.max_num`` entries).
    """
    per_level_boxes = ()
    per_level_valid = ()
    for idx in range(self.num_levels):
        level_anchors = mlvl_anchors[idx]
        k_shape = self.topK_shape[idx]

        # Scores: permute axes with transpose_shape, flatten, activate.
        level_scores = self.transpose(cls_scores[idx], self.transpose_shape)
        level_scores = self.reshape(level_scores, self.reshape_shape)
        level_scores = self.activation(level_scores)
        level_scores = self.cast(self.squeeze(level_scores[::, 0::]), self.ms_type)

        # Regression deltas: same permutation, one row of 4 values per anchor.
        level_deltas = self.transpose(bbox_preds[idx], self.transpose_shape)
        level_deltas = self.cast(self.reshape(level_deltas, (-1, 4)), self.ms_type)

        # Keep the best-scoring candidates of this level.
        top_scores, top_inds = self.topKv2(level_scores, self.topK_stage1[idx])
        top_inds = self.reshape(top_inds, k_shape)
        top_deltas = self.gatherND(level_deltas, top_inds)
        top_anchors = self.cast(self.gatherND(level_anchors, top_inds), self.ms_type)

        # Decode to boxes, append the score as a fifth column, run NMS.
        decoded = self.decode(top_anchors, top_deltas)
        decoded = self.concat_axis1((decoded, self.reshape(top_scores, k_shape)))
        kept_boxes, _, kept_valid = self.nms(decoded)

        per_level_boxes += (kept_boxes,)
        per_level_valid += (kept_valid,)

    proposals = self.concat_axis0(per_level_boxes)
    masks = self.concat_axis0(per_level_valid)

    # The last of the five split columns holds the scores.
    scores = self.squeeze(self.split(proposals)[4])
    # Entries whose mask is False get the filler score so they rank last.
    filler = self.cast(self.topK_mask, self.ms_type)
    ranked_scores = self.select(masks, scores, filler)

    _, final_inds = self.topKv2(ranked_scores, self.max_num)
    final_inds = self.reshape(final_inds, self.topK_shape_stage2)
    proposals = self.gatherND(proposals, final_inds)
    masks = self.gatherND(masks, final_inds)
    return proposals, masks
|
|
@ -0,0 +1,179 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn Rcnn network."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class DenseNoTranpose(nn.Cell):
    """Dense layer whose weight is stored as [input_channels, output_channels].

    The weight is multiplied without transposition (``transpose_b=False``) and
    a bias initialised to zeros is added.

    Args:
        input_channels (int) - Size of the input feature dimension.
        output_channels (int) - Size of the output feature dimension.
        weight_init - Initializer (or tensor) passed to ``initializer`` for the weight.
    """

    def __init__(self, input_channels, output_channels, weight_init):
        super(DenseNoTranpose, self).__init__()
        weight_shape = [input_channels, output_channels]
        self.weight = Parameter(initializer(weight_init, weight_shape, mstype.float32))
        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))

        self.matmul = P.MatMul(transpose_b=False)
        self.bias_add = P.BiasAdd()
        self.cast = P.Cast()
        on_ascend = context.get_context("device_target") == "Ascend"
        self.device_type = "Ascend" if on_ascend else "Others"

    def construct(self, x):
        """Return ``x @ weight + bias``; on Ascend both operands are cast to float16 first."""
        weight = self.weight
        if self.device_type == "Ascend":
            x = self.cast(x, mstype.float16)
            weight = self.cast(weight, mstype.float16)
        return self.bias_add(self.matmul(x, weight), self.bias)
|
||||
|
||||
|
||||
class Rcnn(nn.Cell):
    """
    Rcnn subnet: two shared dense layers followed by a classification head
    and a box-regression head, plus the loss used in training.

    Args:
        config (dict) - Config.
        representation_size (int) - Channels of shared dense.
        batch_size (int) - Batchsize.
        num_classes (int) - Class number.
        target_means (list) - Means for encode function. Default: (0., 0., 0., 0.).
        target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).

    Returns:
        Tuple, tuple of output tensor.

    Examples:
        Rcnn(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
             target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
    """

    def __init__(self,
                 config,
                 representation_size,
                 batch_size,
                 num_classes,
                 target_means=(0., 0., 0., 0.),
                 target_stds=(0.1, 0.1, 0.2, 0.2)
                 ):
        super(Rcnn, self).__init__()
        cfg = config
        self.dtype = np.float32
        self.ms_type = mstype.float32

        # Loss weights and basic sizes taken from the config / arguments.
        self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(self.dtype))
        self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(self.dtype))
        self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
        self.target_means = target_means
        self.target_stds = target_stds
        self.num_classes = num_classes
        self.in_channels = cfg.rcnn_in_channels
        self.train_batch_size = batch_size
        self.test_batch_size = cfg.test_batch_size

        fc_channels = self.rcnn_fc_out_channels

        # Shared dense layers; weights are laid out as (in_features, out_features).
        fc0_weight = initializer("XavierUniform", shape=(representation_size, fc_channels),
                                 dtype=self.ms_type).to_tensor()
        fc1_weight = initializer("XavierUniform", shape=(fc_channels, fc_channels),
                                 dtype=self.ms_type).to_tensor()
        self.shared_fc_0 = DenseNoTranpose(representation_size, fc_channels, fc0_weight)
        self.shared_fc_1 = DenseNoTranpose(fc_channels, fc_channels, fc1_weight)

        # Heads: one score per class, and four regression values per class.
        cls_weight = initializer('Normal', shape=[fc_channels, num_classes],
                                 dtype=self.ms_type).to_tensor()
        reg_weight = initializer('Normal', shape=[fc_channels, num_classes * 4],
                                 dtype=self.ms_type).to_tensor()
        self.cls_scores = DenseNoTranpose(fc_channels, num_classes, cls_weight)
        self.reg_scores = DenseNoTranpose(fc_channels, num_classes * 4, reg_weight)

        # Operators.
        self.flatten = P.Flatten()
        self.relu = P.ReLU()
        self.logicaland = P.LogicalAnd()
        self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
        self.loss_bbox = P.SmoothL1Loss(beta=1.0)
        self.reshape = P.Reshape()
        self.onehot = P.OneHot()
        self.greater = P.Greater()
        self.cast = P.Cast()
        self.sum_loss = P.ReduceSum()
        self.tile = P.Tile()
        self.expandims = P.ExpandDims()

        self.gather = P.GatherNd()
        self.argmax = P.ArgMaxWithValue(axis=1)

        # Constants.
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.value = Tensor(1.0, self.ms_type)

        self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size

        # Multiplier that zeroes the class-0 column of the regression weights.
        rmv_first = np.ones((self.num_bboxes, self.num_classes))
        rmv_first[:, 0] = 0.0
        self.rmv_first_tensor = Tensor(rmv_first.astype(self.dtype))

        self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size

        self.range_max = Tensor(np.arange(self.num_bboxes_test).astype(np.int32))

    def construct(self, featuremap, bbox_targets, labels, mask):
        """Run the heads; return losses in training mode, raw head outputs otherwise."""
        feat = self.flatten(featuremap)
        feat = self.relu(self.shared_fc_0(feat))
        feat = self.relu(self.shared_fc_1(feat))

        x_cls = self.cls_scores(feat)
        x_reg = self.reg_scores(feat)

        if self.training:
            # Label value where (label > 0 and mask) holds, 0 elsewhere.
            positive = self.logicaland(self.greater(labels, 0), mask)
            bbox_weights = self.cast(positive, mstype.int32) * labels
            labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value)
            # Repeat the targets once per class along a new axis 1.
            bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))

            # loss() already returns (loss, loss_cls, loss_reg, loss_print).
            out = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
        else:
            out = (x_cls, (x_cls / self.value), x_reg, x_cls)

        return out

    def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
        """Weighted classification + regression loss.

        Returns:
            Tuple (loss, loss_cls, loss_reg, loss_print), where loss_print is
            the pair (loss_cls, loss_reg).
        """
        weights = self.cast(weights, self.ms_type)
        weight_sum = self.sum_loss(weights, (0,))

        # Classification: weighted sum of per-sample losses over the weight sum.
        loss_cls, _ = self.loss_cls(cls_score, labels)
        loss_cls = self.sum_loss(loss_cls * weights, (0,)) / weight_sum

        # Regression weights: one-hot of the label, with the class-0 column zeroed.
        bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
                                 self.ms_type)
        bbox_weights = bbox_weights * self.rmv_first_tensor

        pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
        loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
        loss_reg = self.sum_loss(loss_reg, (2,))
        loss_reg = loss_reg * bbox_weights
        loss_reg = loss_reg / weight_sum
        loss_reg = self.sum_loss(loss_reg, (0, 1))

        loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
        loss_print = (loss_cls, loss_reg)

        return loss, loss_cls, loss_reg, loss_print
|
|
@ -0,0 +1,262 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Resnet backbone."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
def weight_init_ones(shape):
    """Return a float32 Tensor of the given shape filled with 0.01.

    Note that, despite the function name, the fill value is 0.01.
    """
    values = np.full(shape, 0.01, dtype=np.float32)
    return Tensor(values)
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
    """Create a bias-free square-kernel nn.Conv2d with constant-initialised weights."""
    kernel_shape = (out_channels, in_channels, kernel_size, kernel_size)
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        pad_mode=pad_mode,
        weight_init=weight_init_ones(kernel_shape),
        has_bias=False,
    )
|
||||
|
||||
|
||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
    """Create an nn.BatchNorm2d with gamma=1, beta=0, moving mean=0 and moving variance=1."""
    dtype = np.float32
    return nn.BatchNorm2d(
        out_chls,
        momentum=momentum,
        affine=affine,
        gamma_init=Tensor(np.ones(out_chls, dtype=dtype)),
        beta_init=Tensor(np.zeros(out_chls, dtype=dtype)),
        moving_mean_init=Tensor(np.zeros(out_chls, dtype=dtype)),
        moving_var_init=Tensor(np.ones(out_chls, dtype=dtype)),
        use_batch_statistics=use_batch_statistics,
    )
|
||||
|
||||
|
||||
class ResNetFea(nn.Cell):
    """
    ResNet feature extractor made of a stem and four residual stages.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        weights_update (bool): Weight update flag for the stem convolution and
            the first stage. Default: False.

    Returns:
        Tuple of four tensors, the outputs of stages 1 to 4.

    Raises:
        ValueError: If any of the three lists does not have length 4.

    Examples:
        >>> ResNetFea(ResidualBlockUsing,
        >>>           [3, 4, 6, 3],
        >>>           [64, 256, 512, 1024],
        >>>           [256, 512, 1024, 2048],
        >>>           False)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 weights_update=False):
        super(ResNetFea, self).__init__()

        if any(len(spec) != 4 for spec in (layer_nums, in_channels, out_channels)):
            raise ValueError("the length of "
                             "layer_num, inchannel, outchannel list must be 4!")

        # Batch-norm layers are always built in inference configuration here.
        bn_training = False
        self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPool(kernel_size=3, strides=2, pad_mode="SAME")
        self.weights_update = weights_update

        if not self.weights_update:
            self.conv1.weight.requires_grad = False

        # Stage 1 uses stride 1 and follows the weights_update flag;
        # stages 2-4 use stride 2 and are always built with weights_update=True.
        self.layer1 = self._make_layer(block, layer_nums[0], in_channels[0], out_channels[0],
                                       stride=1, training=bn_training,
                                       weights_update=self.weights_update)
        self.layer2 = self._make_layer(block, layer_nums[1], in_channels[1], out_channels[1],
                                       stride=2, training=bn_training,
                                       weights_update=True)
        self.layer3 = self._make_layer(block, layer_nums[2], in_channels[2], out_channels[2],
                                       stride=2, training=bn_training,
                                       weights_update=True)
        self.layer4 = self._make_layer(block, layer_nums[3], in_channels[3], out_channels[3],
                                       stride=2, training=bn_training,
                                       weights_update=True)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
        """Stack `layer_num` blocks; only the first one changes stride or channels."""
        needs_down_sample = stride != 1 or in_channel != out_channel
        stage = [block(in_channel,
                       out_channel,
                       stride=stride,
                       down_sample=needs_down_sample,
                       training=training,
                       weights_update=weights_update)]
        stage.extend(
            block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
            for _ in range(1, layer_num)
        )
        return nn.SequentialCell(stage)

    def construct(self, x):
        """
        Run the stem and the four stages.

        Args:
            x: input feature data.

        Returns:
            Tuple (c2, c3, c4, c5) of stage outputs; c2 has its gradient
            stopped when weights_update is False.
        """
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        c2 = self.layer1(stem)
        if not self.weights_update:
            # Block gradients from flowing back into the first stage.
            c2 = F.stop_gradient(c2)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        return c2, c3, c4, c5
|
||||
|
||||
|
||||
class ResidualBlockUsing(nn.Cell):
    """
    ResNet V1 bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    The stride is applied in the 3x3 convolution.

    Args:
        in_channels (int) - Input channel.
        out_channels (int) - Output channel.
        stride (int) - Stride size for the 3x3 convolution and the shortcut. Default: 1.
        down_sample (bool) - If to do the downsample in block. Default: False.
        momentum (float) - Momentum for batchnorm layer. Default: 0.1.
        training (bool) - Training flag. Default: False.
        weights_update (bool) - Weights update flag. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        ResidualBlockUsing(3, 256, stride=2, down_sample=True)
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False,
                 momentum=0.1,
                 training=False,
                 weights_update=False):
        super(ResidualBlockUsing, self).__init__()

        self.affine = weights_update

        def _norm(channels):
            # Batch norm for `channels`, switched to train mode when requested.
            bn = _BatchNorm2dInit(channels, momentum=momentum, affine=self.affine,
                                  use_batch_statistics=training)
            return bn.set_train() if training else bn

        mid_channels = out_channels // self.expansion
        self.conv1 = _conv(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = _norm(mid_channels)

        self.conv2 = _conv(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = _norm(mid_channels)

        self.conv3 = _conv(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn3 = _norm(out_channels)

        if not weights_update:
            # Freeze the convolution weights of the main branch.
            for conv in (self.conv1, self.conv2, self.conv3):
                conv.weight.requires_grad = False

        self.relu = P.ReLU()
        self.downsample = down_sample
        if self.downsample:
            # Projection shortcut: 1x1 convolution with the block stride.
            self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
            self.bn_down_sample = _norm(out_channels)
            if not weights_update:
                self.conv_down_sample.weight.requires_grad = False
        self.add = P.Add()

    def construct(self, x):
        """
        Apply the bottleneck branch and add the shortcut.

        Args:
            x: input feature data.

        Returns:
            Tensor, output tensor.
        """
        shortcut = x

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        if self.downsample:
            shortcut = self.bn_down_sample(self.conv_down_sample(shortcut))

        return self.relu(self.add(branch, shortcut))
|
|
@ -0,0 +1,264 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Resnet50v1.0 backbone."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
def weight_init_ones(shape):
    """Return a float32 Tensor of the given shape with every element set to 0.01.

    The name says "ones", but the constant actually written is 0.01.
    """
    constant_weights = np.full(shape, 0.01, dtype=np.float32)
    return Tensor(constant_weights)
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
    """Build an nn.Conv2d without bias whose square kernel is filled by weight_init_ones."""
    init_weights = weight_init_ones((out_channels, in_channels, kernel_size, kernel_size))
    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding,
                       pad_mode=pad_mode, weight_init=init_weights, has_bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
|
||||
|
||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
    """Build an nn.BatchNorm2d initialised to the identity transform.

    gamma and the moving variance start at 1; beta and the moving mean start at 0.
    """
    def _ones():
        return Tensor(np.ones(out_chls, dtype=np.float32))

    def _zeros():
        return Tensor(np.zeros(out_chls, dtype=np.float32))

    gamma_init = _ones()
    beta_init = _zeros()
    moving_mean_init = _zeros()
    moving_var_init = _ones()
    return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
                          beta_init=beta_init, moving_mean_init=moving_mean_init,
                          moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
|
||||
|
||||
|
||||
class ResNetFea(nn.Cell):
    """
    ResNet backbone: a convolutional stem followed by four residual stages.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        weights_update (bool): Weight update flag applied to the stem
            convolution and to the first stage. Default: False.

    Returns:
        Tuple of four tensors, one per residual stage.

    Raises:
        ValueError: If layer_nums, in_channels or out_channels is not of length 4.

    Examples:
        >>> ResNetFea(ResidualBlockUsing_V1,
        >>>           [3, 4, 6, 3],
        >>>           [64, 256, 512, 1024],
        >>>           [256, 512, 1024, 2048],
        >>>           False)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 weights_update=False):
        super(ResNetFea, self).__init__()

        lengths = (len(layer_nums), len(in_channels), len(out_channels))
        if lengths != (4, 4, 4):
            raise ValueError("the length of "
                             "layer_num, inchannel, outchannel list must be 4!")

        # The batch-norm layers of the backbone are built in inference mode.
        bn_training = False
        self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPool(kernel_size=3, strides=2, pad_mode="SAME")
        self.weights_update = weights_update

        if not self.weights_update:
            self.conv1.weight.requires_grad = False

        # First stage: stride 1, trainable only when weights_update is set.
        self.layer1 = self._make_layer(block, layer_nums[0],
                                       in_channel=in_channels[0], out_channel=out_channels[0],
                                       stride=1, training=bn_training,
                                       weights_update=self.weights_update)
        # Remaining stages: stride 2, always built with weights_update=True.
        self.layer2 = self._make_layer(block, layer_nums[1],
                                       in_channel=in_channels[1], out_channel=out_channels[1],
                                       stride=2, training=bn_training, weights_update=True)
        self.layer3 = self._make_layer(block, layer_nums[2],
                                       in_channel=in_channels[2], out_channel=out_channels[2],
                                       stride=2, training=bn_training, weights_update=True)
        self.layer4 = self._make_layer(block, layer_nums[3],
                                       in_channel=in_channels[3], out_channel=out_channels[3],
                                       stride=2, training=bn_training, weights_update=True)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
        """Build one stage: a possibly down-sampling first block plus layer_num - 1 plain blocks."""
        first = block(in_channel,
                      out_channel,
                      stride=stride,
                      down_sample=(stride != 1 or in_channel != out_channel),
                      training=training,
                      weights_update=weights_update)
        followers = [
            block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
            for _ in range(1, layer_num)
        ]
        return nn.SequentialCell([first] + followers)

    def construct(self, x):
        """
        Forward pass of the backbone.

        Args:
            x: input feature data.

        Returns:
            Tuple of the four stage outputs; the first one has its gradient
            stopped when weights_update is False.
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)

        stage1 = self.layer1(out)
        if not self.weights_update:
            # Cut the gradient so the stem and first stage receive no updates.
            stage1 = F.stop_gradient(stage1)
        stage2 = self.layer2(stage1)
        stage3 = self.layer3(stage2)
        stage4 = self.layer4(stage3)

        return stage1, stage2, stage3, stage4
|
||||
|
||||
|
||||
class ResidualBlockUsing_V1(nn.Cell):
    """
    ResNet V1 bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    In this variant the stride is applied in the first 1x1 convolution and the
    3x3 convolution always uses stride 1.

    Args:
        in_channels (int) - Input channel.
        out_channels (int) - Output channel.
        stride (int) - Stride size for the initial convolutional layer. Default: 1.
        down_sample (bool) - If to do the downsample in block. Default: False.
        momentum (float) - Momentum for batchnorm layer. Default: 0.1.
        training (bool) - Training flag. Default: False.
        weights_update (bool) - Weights update flag. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        ResidualBlockUsing_V1(3, 256, stride=2, down_sample=True)
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False,
                 momentum=0.1,
                 training=False,
                 weights_update=False):
        super(ResidualBlockUsing_V1, self).__init__()

        self.affine = weights_update
        bn_kwargs = dict(momentum=momentum, affine=self.affine, use_batch_statistics=training)

        bottleneck = out_channels // self.expansion
        # The stride sits on the first 1x1 convolution; the 3x3 one keeps stride 1.
        self.conv1 = _conv(in_channels, bottleneck, kernel_size=1, stride=stride, padding=0)
        self.bn1 = _BatchNorm2dInit(bottleneck, **bn_kwargs)

        self.conv2 = _conv(bottleneck, bottleneck, kernel_size=3, stride=1, padding=1)
        self.bn2 = _BatchNorm2dInit(bottleneck, **bn_kwargs)

        self.conv3 = _conv(bottleneck, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn3 = _BatchNorm2dInit(out_channels, **bn_kwargs)

        if training:
            self.bn1 = self.bn1.set_train()
            self.bn2 = self.bn2.set_train()
            self.bn3 = self.bn3.set_train()

        if not weights_update:
            # Freeze the three convolutions of the main branch.
            for frozen in (self.conv1, self.conv2, self.conv3):
                frozen.weight.requires_grad = False

        self.relu = P.ReLU()
        self.downsample = down_sample
        if self.downsample:
            # Projection shortcut matching the stride and channels of the main branch.
            self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
            self.bn_down_sample = _BatchNorm2dInit(out_channels, **bn_kwargs)
            if training:
                self.bn_down_sample = self.bn_down_sample.set_train()
            if not weights_update:
                self.conv_down_sample.weight.requires_grad = False
        self.add = P.Add()

    def construct(self, x):
        """
        Run the bottleneck branch, add the (optionally projected) input, apply ReLU.

        Args:
            x: input feature data.

        Returns:
            Tensor, output tensor.
        """
        residual = x

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)

        y = self.conv2(y)
        y = self.bn2(y)
        y = self.relu(y)

        y = self.conv3(y)
        y = self.bn3(y)

        if self.downsample:
            residual = self.conv_down_sample(residual)
            residual = self.bn_down_sample(residual)

        y = self.add(y, residual)
        return self.relu(y)
|
|
@ -0,0 +1,179 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn ROIAlign module."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.nn import layer as L
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
class ROIAlign(nn.Cell):
    """
    Extract RoI features from multiple feature map.

    Thin wrapper around the ``P.ROIAlign`` operator that stores its settings
    so they can be reported by ``repr``.

    Args:
        out_size_h (int) - RoI height.
        out_size_w (int) - RoI width.
        spatial_scale (float) - RoI spatial scale; converted with ``float()``.
        sample_num (int) - RoI sample number; converted with ``int()``. Default: 0.
    """

    def __init__(self,
                 out_size_h,
                 out_size_w,
                 spatial_scale,
                 sample_num=0):
        super(ROIAlign, self).__init__()

        self.out_size = (out_size_h, out_size_w)
        self.spatial_scale = float(spatial_scale)
        self.sample_num = int(sample_num)
        self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
                                   self.spatial_scale, self.sample_num)

    def construct(self, features, rois):
        """Apply the RoI-align operator to `features` for the given `rois`."""
        return self.align_op(features, rois)

    def __repr__(self):
        """Return ``ClassName(out_size=..., spatial_scale=..., sample_num=...)``."""
        # The closing parenthesis was missing, which left the text unbalanced.
        return '{}(out_size={}, spatial_scale={}, sample_num={})'.format(
            self.__class__.__name__, self.out_size, self.spatial_scale, self.sample_num)
|
||||
|
||||
|
||||
class SingleRoIExtractor(nn.Cell):
    """
    Extract RoI features from a single level feature map.

    If there are multiple input feature levels, each RoI is mapped to a level
    according to its scale.

    Args:
        config (dict): Config. ``roi_layer.out_size``, ``roi_layer.sample_num``,
            ``roi_sample_num``, ``rpn_max_num`` and ``test_batch_size`` are read from it.
        roi_layer (dict): Specify RoI layer type and arguments. Not read by this
            class; the RoI layer settings come from ``config.roi_layer``.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (list[int]): Strides of input feature maps, one per level.
        batch_size (int): Batchsize used while training. Default: 1.
        finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
    """

    def __init__(self,
                 config,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 batch_size=1,
                 finest_scale=56):
        super(SingleRoIExtractor, self).__init__()
        cfg = config
        self.train_batch_size = batch_size
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides
        self.num_levels = len(self.featmap_strides)
        self.out_size = config.roi_layer.out_size
        self.sample_num = config.roi_layer.sample_num
        self.roi_layers = self.build_roi_layers(self.featmap_strides)
        self.roi_layers = L.CellList(self.roi_layers)

        self.sqrt = P.Sqrt()
        self.log = P.Log()
        self.finest_scale_ = finest_scale
        self.clamp = C.clip_by_value

        self.cast = P.Cast()
        self.equal = P.Equal()
        self.select = P.Select()

        # Float16 mode is switched off here; both dtypes below follow this flag.
        _mode_16 = False
        self.dtype = np.float16 if _mode_16 else np.float32
        self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
        self.set_train_local(cfg, training=True)

    def set_train_local(self, config, training=True):
        """Set training flag and rebuild the constant tensors sized by the RoI count."""
        self.training_local = training

        cfg = config
        # Init tensor
        # RoIs per image: roi_sample_num when training, rpn_max_num otherwise ...
        self.batch_size = cfg.roi_sample_num if self.training_local else cfg.rpn_max_num
        # ... multiplied by the matching image batch size.
        self.batch_size = self.train_batch_size*self.batch_size \
            if self.training_local else cfg.test_batch_size*self.batch_size
        self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
        finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
        self.finest_scale = Tensor(finest_scale)
        self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6))
        self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
        self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1))
        self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
        # Zero-filled accumulator that `construct` fills level by level.
        self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
                                              self.out_size, self.out_size)), dtype=self.dtype))

    def num_inputs(self):
        """Return the number of feature levels (one per stride)."""
        return len(self.featmap_strides)

    def init_weights(self):
        """No-op: this cell creates no weights of its own."""

    def log2(self, value):
        """Base-2 logarithm computed as log(value) / log(2)."""
        return self.log(value) / self.log(self.twos)

    def build_roi_layers(self, featmap_strides):
        """Create one square ROIAlign cell per stride with spatial_scale = 1 / stride."""
        roi_layers = []
        for s in featmap_strides:
            layer_cls = ROIAlign(self.out_size, self.out_size,
                                 spatial_scale=1 / s,
                                 sample_num=self.sample_num)
            roi_layers.append(layer_cls)
        return roi_layers

    def _c_map_roi_levels(self, rois):
        """Map rois to corresponding feature levels by scales.

        - scale < finest_scale * 2: level 0
        - finest_scale * 2 <= scale < finest_scale * 4: level 1
        - finest_scale * 4 <= scale < finest_scale * 8: level 2
        - scale >= finest_scale * 8: level 3

        The result is clamped to ``[0, num_levels - 1]``.

        Args:
            rois (Tensor): Input RoIs, shape (k, 5); columns 1..4 are read as
                the box corners.

        Returns:
            Tensor: int32 level index (0-based) of each RoI, one column per RoI row.
        """
        # scale = sqrt(width + 1) * sqrt(height + 1)
        scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
            self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)

        target_lvls = self.log2(scale / self.finest_scale + self.epslion)
        target_lvls = P.Floor()(target_lvls)
        target_lvls = self.cast(target_lvls, mstype.int32)
        target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)

        return target_lvls

    def construct(self, rois, feat1, feat2, feat3, feat4):
        """Run ROIAlign on every level and keep, per RoI, the output of its mapped level."""
        feats = (feat1, feat2, feat3, feat4)
        res = self.res_
        target_lvls = self._c_map_roi_levels(rois)
        for i in range(self.num_levels):
            mask = self.equal(target_lvls, P.ScalarToArray()(i))
            mask = P.Reshape()(mask, (-1, 1, 1, 1))
            roi_feats_t = self.roi_layers[i](feats[i], rois)
            # Expand the per-RoI mask to the full feature shape. The channel
            # count follows self.out_channels (previously hard-coded to 256) so
            # it always agrees with the shape of self.res_.
            mask = self.cast(P.Tile()(self.cast(mask, mstype.int32),
                                      (1, self.out_channels, self.out_size, self.out_size)), mstype.bool_)
            res = self.select(mask, roi_feats_t, res)

        return res
|
|
@ -0,0 +1,318 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""RPN for fasterRCNN"""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import initializer
|
||||
from .bbox_assign_sample import BboxAssignSample
|
||||
|
||||
|
||||
class RpnRegClsBlock(nn.Cell):
    """
    Shared 3x3 convolution followed by the RPN classification and regression heads.

    Args:
        in_channels (int) - Input channels of shared convolution.
        feat_channels (int) - Output channels of shared convolution.
        num_anchors (int) - The anchor number.
        cls_out_channels (int) - Output channels of classification convolution.
        weight_conv (Tensor) - weight init for rpn conv.
        bias_conv (Tensor) - bias init for rpn conv.
        weight_cls (Tensor) - weight init for rpn cls conv.
        bias_cls (Tensor) - bias init for rpn cls conv.
        weight_reg (Tensor) - weight init for rpn reg conv.
        bias_reg (Tensor) - bias init for rpn reg conv.

    Returns:
        Tuple of two tensors: (classification output, regression output).
    """
    def __init__(self, in_channels, feat_channels, num_anchors, cls_out_channels,
                 weight_conv, bias_conv, weight_cls, bias_cls, weight_reg, bias_reg):
        super().__init__()
        cls_channels = num_anchors * cls_out_channels
        # Four box deltas are predicted per anchor.
        reg_channels = num_anchors * 4

        self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1,
                                  pad_mode='same', has_bias=True,
                                  weight_init=weight_conv, bias_init=bias_conv)
        self.relu = nn.ReLU()
        self.rpn_cls = nn.Conv2d(feat_channels, cls_channels, kernel_size=1,
                                 pad_mode='valid', has_bias=True,
                                 weight_init=weight_cls, bias_init=bias_cls)
        self.rpn_reg = nn.Conv2d(feat_channels, reg_channels, kernel_size=1,
                                 pad_mode='valid', has_bias=True,
                                 weight_init=weight_reg, bias_init=bias_reg)

    def construct(self, x):
        """Return (cls, reg) head outputs computed from the shared conv + ReLU feature."""
        shared = self.rpn_conv(x)
        shared = self.relu(shared)
        cls_out = self.rpn_cls(shared)
        reg_out = self.rpn_reg(shared)
        return cls_out, reg_out
|
||||
|
||||
|
||||
class RPN(nn.Cell):
    """
    RPN head for Faster R-CNN: per-level cls/reg prediction plus the RPN loss.

    Args:
        config (dict) - Config.
        batch_size (int) - Batchsize.
        in_channels (int) - Input channels of shared convolution.
        feat_channels (int) - Output channels of shared convolution.
        num_anchors (int) - The anchor number.
        cls_out_channels (int) - Output channels of classification convolution.

    Returns:
        Tuple, tuple of output tensor.

    Examples:
        RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
            num_anchors=3, cls_out_channels=512)
    """
    def __init__(self,
                 config,
                 batch_size,
                 in_channels,
                 feat_channels,
                 num_anchors,
                 cls_out_channels):
        super().__init__()
        cfg_rpn = config
        self.dtype = np.float32
        self.ms_type = mstype.float32
        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
        self.num_bboxes = cfg_rpn.num_bboxes

        # slice_index[k]:slice_index[k + 1] is the anchor range of level k for
        # one image; feature_anchor_shape[k] is that count for the whole batch.
        boundaries = [0]
        batch_anchor_counts = []
        for shape in cfg_rpn.feature_shapes:
            level_anchors = shape[0] * shape[1] * num_anchors
            boundaries.append(boundaries[-1] + level_anchors)
            batch_anchor_counts.append(level_anchors * batch_size)
        self.slice_index = tuple(boundaries)
        self.feature_anchor_shape = tuple(batch_anchor_counts)

        self.num_anchors = num_anchors
        self.batch_size = batch_size
        self.test_batch_size = cfg_rpn.test_batch_size
        self.num_layers = 5
        self.real_ratio = Tensor(np.ones((1, 1)).astype(self.dtype))

        self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
                                                                     num_anchors, cls_out_channels))

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.concat = P.Concat(axis=0)
        self.fill = P.Fill()
        self.placeh1 = Tensor(np.ones((1,)).astype(self.dtype))

        # NCHW -> NHWC before flattening the head outputs.
        self.trans_shape = (0, 2, 3, 1)

        self.reshape_shape_reg = (-1, 4)
        self.reshape_shape_cls = (-1,)
        self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(self.dtype))
        self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(self.dtype))
        # Both loss terms are divided by num_expected_neg * batch_size.
        self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(self.dtype))
        self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
        self.CheckValid = P.CheckValid()
        self.sum_loss = P.ReduceSum()
        self.loss_cls = P.SigmoidCrossEntropyWithLogits()
        self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()
        self.loss = Tensor(np.zeros((1,)).astype(self.dtype))
        self.clsloss = Tensor(np.zeros((1,)).astype(self.dtype))
        self.regloss = Tensor(np.zeros((1,)).astype(self.dtype))

    def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
        """
        make rpn layer for rpn proposal network

        Args:
            num_layers (int) - layer num.
            in_channels (int) - Input channels of shared convolution.
            feat_channels (int) - Output channels of shared convolution.
            num_anchors (int) - The anchor number.
            cls_out_channels (int) - Output channels of classification convolution.

        Returns:
            List, list of RpnRegClsBlock cells. Blocks 1.. reuse the weight and
            bias parameters of block 0.
        """
        cls_channels = num_anchors * cls_out_channels
        reg_channels = num_anchors * 4

        def _new_tensor(init, shape):
            return initializer(init, shape=shape, dtype=self.ms_type).to_tensor()

        # Same creation order as before: conv, cls, reg (weight then bias each).
        weight_conv = _new_tensor('Normal', (feat_channels, in_channels, 3, 3))
        bias_conv = _new_tensor(0, (feat_channels,))
        weight_cls = _new_tensor('Normal', (cls_channels, feat_channels, 1, 1))
        bias_cls = _new_tensor(0, (cls_channels,))
        weight_reg = _new_tensor('Normal', (reg_channels, feat_channels, 1, 1))
        bias_reg = _new_tensor(0, (reg_channels,))

        rpn_layer = []
        for _ in range(num_layers):
            block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels,
                                   weight_conv, bias_conv, weight_cls,
                                   bias_cls, weight_reg, bias_reg)
            if self.device_type == "Ascend":
                block.to_float(mstype.float16)
            rpn_layer.append(block)

        # Tie every later block to the parameters of the first one.
        for block in rpn_layer[1:]:
            block.rpn_conv.weight = rpn_layer[0].rpn_conv.weight
            block.rpn_cls.weight = rpn_layer[0].rpn_cls.weight
            block.rpn_reg.weight = rpn_layer[0].rpn_reg.weight

            block.rpn_conv.bias = rpn_layer[0].rpn_conv.bias
            block.rpn_cls.bias = rpn_layer[0].rpn_cls.bias
            block.rpn_reg.bias = rpn_layer[0].rpn_reg.bias

        return rpn_layer

    def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
        """
        Run the heads on every level and, when training, compute the RPN loss.

        Returns:
            Tuple ``(loss, cls_scores, bbox_preds, clsloss, regloss, loss_print)``.
            Outside training the four loss slots hold the placeholder tensor.
        """
        loss_print = ()
        rpn_cls_score = ()
        rpn_bbox_pred = ()
        rpn_cls_score_total = ()
        rpn_bbox_pred_total = ()

        for i in range(self.num_layers):
            cls_map, reg_map = self.rpn_convs_list[i](inputs[i])

            # Raw head outputs are returned to the caller unchanged.
            rpn_cls_score_total = rpn_cls_score_total + (cls_map,)
            rpn_bbox_pred_total = rpn_bbox_pred_total + (reg_map,)

            # Flattened copies are used for the loss.
            cls_flat = self.transpose(cls_map, self.trans_shape)
            cls_flat = self.reshape(cls_flat, self.reshape_shape_cls)

            reg_flat = self.transpose(reg_map, self.trans_shape)
            reg_flat = self.reshape(reg_flat, self.reshape_shape_reg)

            rpn_cls_score = rpn_cls_score + (cls_flat,)
            rpn_bbox_pred = rpn_bbox_pred + (reg_flat,)

        loss = self.loss
        clsloss = self.clsloss
        regloss = self.regloss
        bbox_targets = ()
        bbox_weights = ()
        labels = ()
        label_weights = ()

        output = ()
        if self.training:
            # Step 1: assign targets image by image, then split them per level.
            for i in range(self.batch_size):
                multi_level_flags = ()
                anchor_list_tuple = ()

                for j in range(self.num_layers):
                    res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])),
                                    mstype.int32)
                    multi_level_flags = multi_level_flags + (res,)
                    anchor_list_tuple = anchor_list_tuple + (anchor_list[j],)

                valid_flag_list = self.concat(multi_level_flags)
                anchor_using_list = self.concat(anchor_list_tuple)

                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])

                bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
                                                                                 gt_labels_i,
                                                                                 self.cast(valid_flag_list,
                                                                                           mstype.bool_),
                                                                                 anchor_using_list, gt_valids_i)

                bbox_target = self.cast(bbox_target, self.ms_type)
                bbox_weight = self.cast(bbox_weight, self.ms_type)
                label = self.cast(label, self.ms_type)
                label_weight = self.cast(label_weight, self.ms_type)

                for j in range(self.num_layers):
                    begin = self.slice_index[j]
                    end = self.slice_index[j + 1]
                    stride = 1
                    bbox_targets += (bbox_target[begin:end:stride, ::],)
                    bbox_weights += (bbox_weight[begin:end:stride],)
                    labels += (label[begin:end:stride],)
                    label_weights += (label_weight[begin:end:stride],)

            # Step 2: per level, gather the targets of all images and add the loss.
            for i in range(self.num_layers):
                bbox_target_using = ()
                bbox_weight_using = ()
                label_using = ()
                label_weight_using = ()

                for j in range(self.batch_size):
                    # Targets are stored image-major: entry i + num_layers * j
                    # is level i of image j.
                    bbox_target_using += (bbox_targets[i + (self.num_layers * j)],)
                    bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],)
                    label_using += (labels[i + (self.num_layers * j)],)
                    label_weight_using += (label_weights[i + (self.num_layers * j)],)

                bbox_target_with_batchsize = self.concat(bbox_target_using)
                bbox_weight_with_batchsize = self.concat(bbox_weight_using)
                label_with_batchsize = self.concat(label_using)
                label_weight_with_batchsize = self.concat(label_weight_using)

                # Targets are constants for the backward pass.
                bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
                bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
                label_ = F.stop_gradient(label_with_batchsize)
                label_weight_ = F.stop_gradient(label_weight_with_batchsize)

                cls_score_i = self.cast(rpn_cls_score[i], self.ms_type)
                reg_score_i = self.cast(rpn_bbox_pred[i], self.ms_type)

                loss_cls = self.loss_cls(cls_score_i, label_)
                loss_cls_item = loss_cls * label_weight_
                loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total

                loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
                # Broadcast the per-anchor weight over the four box coordinates.
                bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4))
                loss_reg = loss_reg * bbox_weight_
                loss_reg_item = self.sum_loss(loss_reg, (1,))
                loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total

                loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item

                loss += loss_total
                loss_print += (loss_total, loss_cls_item, loss_reg_item)
                clsloss += loss_cls_item
                regloss += loss_reg_item

                output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
        else:
            output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)

        return output
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""
|
||||
convert resnet pretrain model to faster_rcnn backbone pretrain model
|
||||
"""
|
||||
from mindspore.train.serialization import load_checkpoint, save_checkpoint
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from .model_utils.config import config
|
||||
|
||||
|
||||
def load_weights(model_path, use_fp16_weight):
    """
    Load a resnet pretrain checkpoint and rename its entries for the backbone.

    Names starting with "layer", "conv1" or "bn" get a "backbone." prefix, and
    "down_sample_layer.0" / "down_sample_layer.1" are renamed to
    "conv_down_sample" / "bn_down_sample".

    Args:
        model_path (str): resnet pretrain checkpoint file.
        use_fp16_weight (bool): whether save weight into float16.

    Returns:
        parameter list (list): one ``{"name": ..., "data": Parameter}`` dict per weight.
    """
    backbone_prefixes = ("layer", "conv1", "bn")
    renames = (
        ("down_sample_layer.0", "conv_down_sample"),
        ("down_sample_layer.1", "bn_down_sample"),
    )

    checkpoint = load_checkpoint(model_path)
    arrays = {}
    for src_name, src_param in checkpoint.items():
        dst_name = "backbone." + src_name if src_name.startswith(backbone_prefixes) else src_name
        # str.replace leaves the name untouched when the pattern is absent.
        for old, new in renames:
            dst_name = dst_name.replace(old, new)
        arrays[dst_name] = src_param.data.asnumpy()

    target_dtype = mstype.float16 if use_fp16_weight else mstype.float32
    return [
        {"name": name, "data": Parameter(Tensor(array, target_dtype), name=name)}
        for name, array in arrays.items()
    ]
|
||||
|
||||
if __name__ == "__main__":
    # Convert the checkpoint named by config.ckpt_file (float32 weights) and
    # write the renamed parameters to resnet_backbone.ckpt.
    save_checkpoint(load_weights(config.ckpt_file, use_fp16_weight=False),
                    "resnet_backbone.ckpt")
|
|
@ -0,0 +1,483 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FasterRcnn dataset"""
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from numpy import random
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
    """Calculate the ious between each bbox of bboxes1 and bboxes2.

    Box widths and heights are measured as ``max - min + 1``.

    Args:
        bboxes1(ndarray): shape (n, 4)
        bboxes2(ndarray): shape (k, 4)
        mode(str): iou (intersection over union) or iof (intersection
            over foreground, i.e. over the area of the bboxes1 box)

    Returns:
        ious(ndarray): shape (n, k), dtype float32
    """
    assert mode in ['iou', 'iof']

    boxes_a = bboxes1.astype(np.float32)
    boxes_b = bboxes2.astype(np.float32)
    num_a, num_b = boxes_a.shape[0], boxes_b.shape[0]
    if num_a * num_b == 0:
        return np.zeros((num_a, num_b), dtype=np.float32)

    # Loop over the smaller set; the result is transposed back at the end.
    swapped = num_a > num_b
    if swapped:
        boxes_a, boxes_b = boxes_b, boxes_a

    areas_a = (boxes_a[:, 2] - boxes_a[:, 0] + 1) * (boxes_a[:, 3] - boxes_a[:, 1] + 1)
    areas_b = (boxes_b[:, 2] - boxes_b[:, 0] + 1) * (boxes_b[:, 3] - boxes_b[:, 1] + 1)

    result = np.zeros((boxes_a.shape[0], boxes_b.shape[0]), dtype=np.float32)
    for idx in range(boxes_a.shape[0]):
        left = np.maximum(boxes_a[idx, 0], boxes_b[:, 0])
        top = np.maximum(boxes_a[idx, 1], boxes_b[:, 1])
        right = np.minimum(boxes_a[idx, 2], boxes_b[:, 2])
        bottom = np.minimum(boxes_a[idx, 3], boxes_b[:, 3])
        overlap = np.maximum(right - left + 1, 0) * np.maximum(bottom - top + 1, 0)

        if mode == 'iou':
            denominator = areas_a[idx] + areas_b - overlap
        elif swapped:
            # After the swap the original bboxes1 areas live in areas_b.
            denominator = areas_b
        else:
            denominator = areas_a[idx]
        result[idx, :] = overlap / denominator

    return result.T if swapped else result
|
||||
|
||||
|
||||
class PhotoMetricDistortion:
    """Randomly jitter brightness, contrast, saturation, hue and channel order."""
    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def _maybe_scale_contrast(self, img):
        """With a coin flip, multiply `img` in place by a random contrast factor."""
        if random.randint(2):
            img *= random.uniform(self.contrast_lower, self.contrast_upper)
        return img

    def __call__(self, img, boxes, labels):
        """Distort `img` (worked on as a float32 copy); `boxes` and `labels` pass through."""
        img = img.astype('float32')

        # random brightness
        if random.randint(2):
            img += random.uniform(-self.brightness_delta, self.brightness_delta)

        # contrast_last == 0 --> do random contrast first
        # contrast_last == 1 --> do random contrast last
        contrast_last = random.randint(2)
        if contrast_last == 1:
            img = self._maybe_scale_contrast(img)

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower, self.saturation_upper)

        # random hue, wrapped back into [0, 360]
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if contrast_last == 0:
            img = self._maybe_scale_contrast(img)

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        return img, boxes, labels
|
||||
|
||||
|
||||
class Expand:
    """Randomly place the image on a larger canvas filled with a mean colour."""
    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        # The fill colour is stored reversed when to_rgb is set.
        self.mean = mean[::-1] if to_rgb else mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels):
        """Return (img, boxes, labels); `boxes` is shifted in place when expanding."""
        # Coin flip: leave the sample untouched half of the time.
        if random.randint(2):
            return img, boxes, labels

        height, width, channels = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        canvas = np.full((int(height * ratio), int(width * ratio), channels),
                         self.mean).astype(img.dtype)

        # Random top-left corner of the original image inside the canvas.
        left = int(random.uniform(0, width * ratio - width))
        top = int(random.uniform(0, height * ratio - height))
        canvas[top:top + height, left:left + width] = img

        # Shift (x1, y1, x2, y2) by (left, top, left, top).
        boxes += np.tile((left, top), 2)
        return canvas, boxes, labels
|
||||
|
||||
|
||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num, config):
    """Rescale image and boxes with mmcv.imrescale, then zero-pad to the config size."""
    target_h = config.img_height
    target_w = config.img_width

    resized, scale = mmcv.imrescale(img, (target_w, target_h), return_scale=True)
    if resized.shape[0] > target_h:
        # Still too tall: rescale once more and combine both factors.
        resized, extra_scale = mmcv.imrescale(resized, (target_h, target_h), return_scale=True)
        scale = scale*extra_scale

    # Scale the boxes and keep them inside the resized (unpadded) image.
    scaled_boxes = gt_bboxes * scale
    scaled_boxes[:, 0::2] = np.clip(scaled_boxes[:, 0::2], 0, resized.shape[1] - 1)
    scaled_boxes[:, 1::2] = np.clip(scaled_boxes[:, 1::2], 0, resized.shape[0] - 1)

    pad_h = target_h - resized.shape[0]
    pad_w = target_w - resized.shape[1]
    assert ((pad_h >= 0) and (pad_w >= 0))

    # The resized image sits in the top-left corner of a zero canvas.
    padded = np.zeros((target_h, target_w, 3)).astype(resized.dtype)
    padded[0:resized.shape[0], 0:resized.shape[1], :] = resized

    out_shape = np.asarray((target_h, target_w, 1.0), dtype=np.float32)

    return (padded, out_shape, scaled_boxes, gt_label, gt_num)
|
||||
|
||||
def rescale_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, config):
    """Eval rescale: resize and zero-pad the image; boxes are returned unchanged.

    The returned shape array is the incoming `img_shape` with the scale factor
    appended twice.
    """
    target_h = config.img_height
    target_w = config.img_width

    resized, scale = mmcv.imrescale(img, (target_w, target_h), return_scale=True)
    if resized.shape[0] > target_h:
        # Still too tall: rescale once more and combine both factors.
        resized, extra_scale = mmcv.imrescale(resized, (target_h, target_h), return_scale=True)
        scale = scale*extra_scale

    pad_h = target_h - resized.shape[0]
    pad_w = target_w - resized.shape[1]
    assert ((pad_h >= 0) and (pad_w >= 0))

    # The resized image sits in the top-left corner of a zero canvas.
    padded = np.zeros((target_h, target_w, 3)).astype(resized.dtype)
    padded[0:resized.shape[0], 0:resized.shape[1], :] = resized

    out_shape = np.asarray(np.append(img_shape, (scale, scale)), dtype=np.float32)

    return (padded, out_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num, config):
    """Resize image to the config size (no aspect-ratio keeping) and scale the boxes."""
    resized, w_scale, h_scale = mmcv.imresize(
        img, (config.img_width, config.img_height), return_scale=True)

    # Per-coordinate factors for (x1, y1, x2, y2).
    box_scale = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    out_shape = np.asarray((config.img_height, config.img_width, 1.0), dtype=np.float32)

    scaled_boxes = gt_bboxes * box_scale
    # Clip x to [0, width - 1] and y to [0, height - 1].
    scaled_boxes[:, 0::2] = np.clip(scaled_boxes[:, 0::2], 0, out_shape[1] - 1)
    scaled_boxes[:, 1::2] = np.clip(scaled_boxes[:, 1::2], 0, out_shape[0] - 1)

    return (resized, out_shape, scaled_boxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, config):
    """Eval resize: resize image to the config size and scale the boxes.

    Returns:
        Tuple (img_data, img_shape, gt_bboxes, gt_label, gt_num), where
        img_shape is the incoming `img_shape` with (h_scale, w_scale) appended,
        as float32.
    """
    img_data = img
    img_data, w_scale, h_scale = mmcv.imresize(
        img_data, (config.img_width, config.img_height), return_scale=True)
    # Per-coordinate factors for (x1, y1, x2, y2).
    scale_factor = np.array(
        [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    img_shape = np.append(img_shape, (h_scale, w_scale))
    img_shape = np.asarray(img_shape, dtype=np.float32)

    gt_bboxes = gt_bboxes * scale_factor

    # The scaled boxes live in the resized frame, so clip against the resized
    # size. The previous bounds came from `img_shape`, which here still holds
    # the size the caller passed in, not the resized one.
    gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, config.img_width - 1)
    gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, config.img_height - 1)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, config):
    """Pad image with mmcv.impad to (img_height, img_width) and cast it to float32."""
    target_shape = (config.img_height, config.img_width)
    padded = mmcv.impad(img, target_shape).astype(np.float32)
    return (padded, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Normalize image with fixed per-channel mean/std via mmcv and cast to float32."""
    channel_mean = np.array([123.675, 116.28, 103.53])
    channel_std = np.array([58.395, 57.12, 57.375])
    normalized = mmcv.imnormalize(img, channel_mean, channel_std, True)
    return (normalized.astype(np.float32), img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Flip image with mmcv.imflip and mirror the boxes' x coordinates."""
    flipped_img = mmcv.imflip(img)
    mirrored = gt_bboxes.copy()
    _, width, _ = flipped_img.shape

    # New x1 comes from the old x2 and vice versa; y values stay as they are.
    mirrored[..., 0::4] = width - gt_bboxes[..., 2::4] - 1
    mirrored[..., 2::4] = width - gt_bboxes[..., 0::4] - 1

    return (flipped_img, img_shape, mirrored, gt_label, gt_num)
|
||||
|
||||
|
||||
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Move image channels first (HWC -> CHW) and cast every column to its final dtype.

    Returns:
        Tuple (img_data, img_shape, gt_bboxes, gt_label, gt_num) with dtypes
        float32, float32, float32, int32 and bool respectively.
    """
    img_data = img.transpose(2, 0, 1).copy()
    img_data = img_data.astype(np.float32)
    img_shape = img_shape.astype(np.float32)
    gt_bboxes = gt_bboxes.astype(np.float32)
    gt_label = gt_label.astype(np.int32)
    # np.bool_ instead of the np.bool alias, which NumPy 1.24 removed.
    gt_num = gt_num.astype(np.bool_)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Run a freshly constructed ``PhotoMetricDistortion`` on the sample.

    The distortion object is called with ``(img, gt_bboxes, gt_label)`` and
    its three results replace those columns; ``img_shape`` and ``gt_num``
    are passed through unchanged.
    """
    distortion = PhotoMetricDistortion()
    distorted_img, gt_bboxes, gt_label = distortion(img, gt_bboxes, gt_label)
    return distorted_img, img_shape, gt_bboxes, gt_label, gt_num
|
||||
|
||||
|
||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Run a freshly constructed ``Expand`` transform on the sample.

    The transform is called with ``(img, gt_bboxes, gt_label)`` and its
    three results replace those columns; ``img_shape`` and ``gt_num`` are
    passed through unchanged.
    """
    expander = Expand()
    expanded_img, gt_bboxes, gt_label = expander(img, gt_bboxes, gt_label)
    return expanded_img, img_shape, gt_bboxes, gt_label, gt_num
|
||||
|
||||
|
||||
def preprocess_fn(image, box, is_training, config):
    """Preprocess one (image, annotation) sample for the dataset pipeline.

    Args:
        image: 3-D image array; channels 0 and 2 are swapped before further
            processing.
        box: 2-D array whose columns are read as ``[:4]`` box coordinates,
            ``[4]`` label and ``[5]`` iscrowd flag.
        is_training: When false only rescale/resize, normalize and transpose
            are applied; when true random expand and flip are added.
        config: Object providing ``keep_ratio``, ``flip_ratio`` and
            ``expand_ratio`` (and whatever the column helpers read).

    Returns:
        tuple: The output of ``transpose_column``:
        ``(image, image_shape, boxes, labels, valid_mask)``.
    """
    def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
        """Evaluation path: rescale/resize, normalize, transpose."""
        image_shape = image_shape[:2]
        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

        if config.keep_ratio:
            input_data = rescale_column_test(*input_data, config=config)
        else:
            input_data = resize_column_test(*input_data, config=config)
        input_data = imnormalize_column(*input_data)

        output_data = transpose_column(*input_data)
        return output_data

    def _data_aug(image, box, is_training):
        """Data augmentation function."""
        # Swap channel 0 and channel 2 on a copy of the input image.
        image_bgr = image.copy()
        image_bgr[:, :, 0] = image[:, :, 2]
        image_bgr[:, :, 1] = image[:, :, 1]
        image_bgr[:, :, 2] = image[:, :, 0]
        image_shape = image_bgr.shape[:2]
        gt_box = box[:, :4]
        gt_label = box[:, 4]
        gt_iscrowd = box[:, 5]

        # Pad every annotation column to a fixed number of rows. Padding rows
        # get label -1 and iscrowd 1. NOTE: np.pad raises if the sample has
        # more than ``pad_max_number`` boxes.
        pad_max_number = 128
        gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0)
        gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1)
        gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1)
        # Invert the crowd flag: 1 marks a non-crowd box, 0 a crowd/padding row.
        # ``np.bool_`` replaces the alias ``np.bool`` removed in NumPy 1.24.
        gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool_))).astype(np.int32)

        if not is_training:
            return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)

        # Draw the random decisions in this order so seeded runs stay reproducible.
        flip = (np.random.rand() < config.flip_ratio)
        expand = (np.random.rand() < config.expand_ratio)
        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

        if expand:
            input_data = expand_column(*input_data)
        if config.keep_ratio:
            input_data = rescale_column(*input_data, config=config)
        else:
            input_data = resize_column(*input_data, config=config)
        input_data = imnormalize_column(*input_data)
        if flip:
            input_data = flip_column(*input_data)

        output_data = transpose_column(*input_data)
        return output_data

    return _data_aug(image, box, is_training)
|
||||
|
||||
|
||||
def create_coco_label(is_training, config):
    """Collect image paths and annotations from a COCO-format annotation file.

    Args:
        is_training: Selects ``config.train_data_type`` instead of
            ``config.val_data_type`` for the annotation file name.
        config: Object providing ``coco_root``, ``val_data_type``,
            ``train_data_type``, ``coco_classes`` and ``instance_set``
            (a format string filled with the data type).

    Returns:
        tuple: ``(image_files, image_anno_dict)``. ``image_files`` lists the
        path of every image; ``image_anno_dict`` maps each path to an array
        of ``[x1, y1, x2, y2, class_index, iscrowd]`` rows, or to the single
        row ``[0, 0, 0, 0, 0, 1]`` when no annotation of a wanted class exists.
    """
    from pycocotools.coco import COCO

    coco_root = config.coco_root
    data_type = config.val_data_type
    if is_training:
        data_type = config.train_data_type

    # Map each class name we train/test on to its index in the class list.
    wanted_index = {name: idx for idx, name in enumerate(config.coco_classes)}

    anno_json = os.path.join(coco_root, config.instance_set.format(data_type))
    coco = COCO(anno_json)
    id_to_name = {cat["id"]: cat["name"] for cat in coco.loadCats(coco.getCatIds())}

    image_files = []
    image_anno_dict = {}
    for img_id in coco.getImgIds():
        file_name = coco.loadImgs(img_id)[0]["file_name"]
        records = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None))
        # Images are looked up directly under coco_root (no data_type subfolder).
        image_path = os.path.join(coco_root, file_name)

        rows = []
        for record in records:
            class_name = id_to_name[record["category_id"]]
            if class_name not in wanted_index:
                continue
            left, top, box_w, box_h = record["bbox"][0], record["bbox"][1], record["bbox"][2], record["bbox"][3]
            # COCO boxes are [x, y, w, h]; store them as corner coordinates.
            rows.append([left, top, left + box_w, top + box_h,
                         wanted_index[class_name], int(record["iscrowd"])])

        image_files.append(image_path)
        image_anno_dict[image_path] = np.array(rows) if rows else np.array([0, 0, 0, 0, 0, 1])

    return image_files, image_anno_dict
|
||||
|
||||
|
||||
def anno_parser(annos_str):
    """Convert comma-separated annotation strings into lists of ints.

    Args:
        annos_str: Iterable of strings such as ``"1,2,3,4,0,0"``; surrounding
            whitespace of each string is stripped before splitting.

    Returns:
        list[list[int]]: One list of integers per input string.
    """
    return [[int(token) for token in entry.strip().split(',')] for entry in annos_str]
|
||||
|
||||
|
||||
def filter_valid_data(image_dir, anno_path):
    """Keep the annotated images that actually exist inside ``image_dir``.

    Each line of the annotation file is split on single spaces: the first
    field is the image file name, the remaining fields are parsed with
    ``anno_parser``.

    Args:
        image_dir: Directory containing the images.
        anno_path: Path of the annotation text file (decoded as UTF-8).

    Returns:
        tuple: ``(image_files, image_anno_dict)`` for the images found on disk.

    Raises:
        RuntimeError: If ``image_dir`` is not a directory or ``anno_path``
            is not a file.
    """
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as anno_file:
        raw_lines = anno_file.readlines()

    image_files = []
    image_anno_dict = {}
    for raw_line in raw_lines:
        fields = str(raw_line.decode("utf-8").strip()).split(' ')
        image_path = os.path.join(image_dir, fields[0])
        if not os.path.isfile(image_path):
            continue
        image_anno_dict[image_path] = anno_parser(fields[1:])
        image_files.append(image_path)
    return image_files, image_anno_dict
|
||||
|
||||
|
||||
def data_to_mindrecord_byte_image(config, dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
    """Write images and their annotations into MindRecord files.

    Args:
        config: Object providing ``mindrecord_dir`` plus the fields read by
            ``create_coco_label`` (COCO) or ``IMAGE_DIR``/``ANNO_PATH`` (other).
        dataset: ``"coco"`` uses ``create_coco_label``; any other value uses
            ``filter_valid_data``.
        is_training: Forwarded to ``create_coco_label``.
        prefix: File name joined onto ``config.mindrecord_dir``.
        file_num: Second argument given to ``FileWriter``.
    """
    record_dir = config.mindrecord_dir
    writer = FileWriter(os.path.join(record_dir, prefix), file_num)

    if dataset != "coco":
        image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
    else:
        image_files, image_anno_dict = create_coco_label(is_training, config=config)

    # Raw encoded image bytes plus an int32 annotation table with 6 columns.
    schema = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(schema, "fasterrcnn_json")

    for image_path in image_files:
        with open(image_path, 'rb') as image_file:
            encoded = image_file.read()
        boxes = np.array(image_anno_dict[image_path], dtype=np.int32)
        writer.write_raw_data([{"image": encoded, "annotation": boxes}])
    writer.commit()
|
||||
|
||||
|
||||
def create_fasterrcnn_dataset(config, mindrecord_file, batch_size=2, device_num=1, rank_id=0, is_training=True,
                              num_parallel_workers=8, python_multiprocessing=False):
    """Create FasterRcnn dataset with MindDataset.

    Args:
        config: Configuration forwarded to ``preprocess_fn``.
        mindrecord_file: MindRecord file(s) to read.
        batch_size: Batch size; incomplete batches are dropped.
        device_num: Number of shards.
        rank_id: Shard index of this process.
        is_training: Enables shuffling and the training preprocessing path.
        num_parallel_workers: Workers used by the preprocessing map.
        python_multiprocessing: Only forwarded to the map when training.

    Returns:
        The batched dataset with columns
        ``image, image_shape, box, label, valid_num``.
    """
    cv2.setNumThreads(0)
    de.config.set_prefetch_size(8)
    # The reader itself always uses 4 workers, independent of num_parallel_workers.
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num,
                        shard_id=rank_id, num_parallel_workers=4, shuffle=is_training)
    ds = ds.map(input_columns=["image"], operations=C.Decode())

    def _preprocess(image, annotation):
        return preprocess_fn(image, annotation, is_training, config=config)

    map_kwargs = {
        "input_columns": ["image", "annotation"],
        "output_columns": ["image", "image_shape", "box", "label", "valid_num"],
        "column_order": ["image", "image_shape", "box", "label", "valid_num"],
        "operations": _preprocess,
        "num_parallel_workers": num_parallel_workers,
    }
    if is_training:
        # The evaluation path leaves python_multiprocessing at the map default.
        map_kwargs["python_multiprocessing"] = python_multiprocessing

    ds = ds.map(**map_kwargs)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr generator for fasterrcnn"""
|
||||
import math
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Return the linearly interpolated warm-up learning rate.

    The rate starts at ``init_lr`` for step 0 and grows by
    ``(base_lr - init_lr) / warmup_steps`` per step.
    """
    start = float(init_lr)
    per_step_gain = (float(base_lr) - start) / float(warmup_steps)
    return start + per_step_gain * current_step
|
||||
|
||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """Return the cosine-annealed learning rate for ``current_step``.

    The progress ``(current_step - warmup_steps) / decay_steps`` is mapped
    through ``(1 + cos(pi * progress)) / 2`` and scaled by ``base_lr``.
    """
    progress = float(current_step - warmup_steps) / float(decay_steps)
    cosine_factor = (1 + math.cos(progress * math.pi)) / 2
    return cosine_factor * base_lr
|
||||
|
||||
def dynamic_lr(config, steps_per_epoch):
    """Build the per-step learning-rate list: linear warm-up, then cosine.

    Args:
        config: Object providing ``base_lr``, ``epoch_size``, ``warmup_step``
            and ``warmup_ratio``.
        steps_per_epoch: Number of optimizer steps per epoch.

    Returns:
        list[float]: ``steps_per_epoch * (config.epoch_size + 1)`` rates.
    """
    peak_lr = config.base_lr
    num_steps = steps_per_epoch * (config.epoch_size + 1)
    num_warmup = int(config.warmup_step)

    schedule = []
    for step in range(num_steps):
        if step >= num_warmup:
            # The cosine phase uses the total step count as its decay length.
            schedule.append(a_cosine_learning_rate(step, peak_lr, num_warmup, num_steps))
            continue
        schedule.append(linear_warmup_learning_rate(step, num_warmup, peak_lr, peak_lr * config.warmup_ratio))

    return schedule
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
    """
    Configuration namespace. Convert dictionary to members.

    Nested dictionaries become nested ``Config`` objects. Lists and tuples
    become lists whose dictionary items are converted the same way.
    """
    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            setattr(self, key, Config._wrap(value))

    @staticmethod
    def _wrap(value):
        """Return ``value`` with dictionaries replaced by ``Config`` objects."""
        if isinstance(value, (list, tuple)):
            return [Config(item) if isinstance(item, dict) else item for item in value]
        if isinstance(value, dict):
            return Config(value)
        return value

    def __str__(self):
        return pformat(vars(self))

    def __repr__(self):
        return str(self)
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    One ``--<key>`` option is registered for every top-level entry of ``cfg``
    whose value is neither a list nor a dict. Boolean entries are parsed with
    ``ast.literal_eval``; all others with the type of their default value.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Mapping from option name to its allowed values.
        cfg_path: Path to the default yaml config.

    Returns:
        The namespace produced by ``parse_args()`` on the process arguments.
    """
    cli_parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                         parents=[parser])
    helper = helper if helper is not None else {}
    choices = choices if choices is not None else {}

    for key in cfg:
        default_value = cfg[key]
        if isinstance(default_value, (list, dict)):
            continue
        if key in helper:
            description = helper[key]
        else:
            description = "Please reference to {}".format(cfg_path)
        allowed = choices[key] if key in choices else None
        # bool("False") is True, so booleans are evaluated as Python literals.
        converter = ast.literal_eval if isinstance(default_value, bool) else type(default_value)
        cli_parser.add_argument("--" + key, type=converter, default=default_value, choices=allowed,
                                help=description)

    return cli_parser.parse_args()
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    The file may hold up to three YAML documents, in this order: the
    configuration, the help descriptions, and the allowed choices. Missing
    trailing documents default to empty dictionaries.

    Args:
        yaml_path: Path to the yaml config.

    Returns:
        tuple: ``(cfg, cfg_helper, cfg_choices)``.

    Raises:
        ValueError: If the file is not valid YAML, contains no document, or
            contains more than three documents.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # Only YAML syntax problems are reported as parse failures; the
            # previous bare ``except`` also hid the document-count error
            # below and swallowed KeyboardInterrupt.
            raise ValueError("Failed to parse yaml") from err

    if not cfgs:
        raise ValueError("Failed to parse yaml: no document found in {}".format(yaml_path))
    if len(cfgs) > 3:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")

    cfg = cfgs[0]
    cfg_helper = cfgs[1] if len(cfgs) > 1 else {}
    cfg_choices = cfgs[2] if len(cfgs) > 2 else {}
    print(cfg_helper)
    return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Every attribute of ``args`` overwrites (or adds) the entry of the same
    name in ``cfg``; ``cfg`` is modified in place.

    Args:
        args: Command line arguments.
        cfg: Base configuration.

    Returns:
        The same ``cfg`` object after the update.
    """
    for name, value in vars(args).items():
        cfg[name] = value
    return cfg
|
||||
|
||||
|
||||
def get_config():
    """
    Get Config according to the yaml file and cli arguments.

    ``--config_path`` is read first (default: ``../../default_config.yaml``
    relative to this file), the yaml is parsed and printed, then every
    scalar entry can be overridden from the command line.

    Returns:
        Config: The merged configuration.
    """
    base_parser = argparse.ArgumentParser(description="default name", add_help=False)
    module_dir = os.path.dirname(os.path.abspath(__file__))
    default_yaml = os.path.join(module_dir, "../../default_config.yaml")
    base_parser.add_argument("--config_path", type=str, default=default_yaml,
                             help="Config file path")
    # Unknown options are tolerated here; they are validated in the second pass.
    known_args, _ = base_parser.parse_known_args()

    default, helper, choices = parse_yaml(known_args.config_path)
    pprint(default)
    cli_args = parse_cli_to_yaml(parser=base_parser, cfg=default, helper=helper, choices=choices,
                                 cfg_path=known_args.config_path)
    return Config(merge(cli_args, default))
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int.

    NOTE(review): the fallback when the variable is unset is 5 here, whereas
    the RANK_* helpers fall back to 0/1 - confirm this is intentional.
    """
    return int(os.environ.get('DEVICE_ID', '5'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (default 1)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (default 0)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """Return the fixed job identifier used for local (non-ModelArts) runs."""
    local_job_name = "Local Job"
    return local_job_name
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (default 0)."""
    return int(os.environ.get('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (default 1)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (default 0)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """Return the ``JOB_ID`` environment variable, or ``"default"``.

    The fallback applies both when the variable is unset and when it is set
    to an empty string. Previously an unset variable produced ``None``,
    because ``os.getenv`` returns ``None`` and ``None != ""`` is true.

    Returns:
        str: The job identifier.
    """
    return os.getenv('JOB_ID') or "default"
|
||||
|
||||
def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path
    Upload data from local directory to remote obs in contrast.

    One device per server performs the copy and then creates a lock file;
    every caller then polls once per second until that lock file exists.
    Each call uses a new lock file name (a module-level counter is appended).
    """
    import moxing as mox
    import time
    global _global_sync_count

    lock_path = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices as most.
    is_copy_device = get_device_id() % min(get_device_num(), 8) == 0
    if is_copy_device and not os.path.exists(lock_path):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(lock_path)
        except IOError:
            # NOTE(review): if the lock cannot be created for a reason other
            # than "already exists", the wait loop below never terminates.
            pass
        print("===save flag===")

    while not os.path.exists(lock_path):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.

    Args:
        pre_process: Optional callable run after the downloads and before the
            wrapped function (only when ``config.enable_modelarts`` is set).
        post_process: Optional callable run after the wrapped function and
            before the upload (only when ``config.enable_modelarts`` is set).

    Returns:
        A decorator. The decorated function now returns whatever the wrapped
        function returns (it previously always returned ``None``).
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            # Keep the result so the decorator no longer discards it.
            result = run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
            return result
        return wrapped_func
    return wrapper
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn training network wrapper."""
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
|
||||
|
||||
class LossCallBack(Callback):
    """
    Monitor the loss in training.

    After every step one line is appended to ``./loss_<rank_id>.log`` with
    the seconds elapsed since the first ``LossCallBack`` was created, the
    epoch, the step within the epoch and the loss.

    Note:
        ``per_print_times`` is validated and stored but the current
        implementation logs on every step regardless of its value.

    Args:
        per_print_times (int): Must be an int >= 0. Default: 1.
        rank_id (int): Used in the log file name. Default: 0.

    Raises:
        ValueError: If ``per_print_times`` is not an int or is negative.
    """

    def __init__(self, per_print_times=1, rank_id=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.count = 0
        self.loss_sum = 0
        self.rank_id = rank_id

        # The reference timestamp is shared module-wide and set only once.
        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = time.time()
            time_stamp_init = True

    def step_end(self, run_context):
        """Accumulate the step loss and append one line to the log file."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        self.count += 1
        self.loss_sum += float(loss)

        if self.count >= 1:
            time_stamp_current = time.time()
            total_loss = self.loss_sum / self.count

            # ``with`` guarantees the handle is closed even if a write fails.
            with open("./loss_{}.log".format(self.rank_id), "a+") as loss_file:
                loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
                                (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                                 total_loss))
                loss_file.write("\n")

            self.count = 0
            self.loss_sum = 0
|
||||
|
||||
|
||||
class LossNet(nn.Cell):
    """FasterRcnn loss method: returns the sum of the first two inputs.

    The remaining four inputs are accepted but not used.
    """
    def construct(self, x1, x2, x3, x4, x5, x6):
        total = x1 + x2
        return total
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
    """
    Wrap the network with loss function to compute loss.

    The backbone is expected to return six values, which are forwarded
    positionally to the loss function.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
    """
    def __init__(self, backbone, loss_fn):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
        out1, out2, out3, out4, out5, out6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
        return self._loss_fn(out1, out2, out3, out4, out5, out6)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, return backbone network.
        """
        return self._backbone
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
    """
    Network training package class.

    Append an optimizer to the training network after that the construct function
    can be called to create the backward graph.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default value is 1.0.
        reduce_flag (bool): The reduce flag. Default value is False.
        mean (bool): Allreduce method. Default value is True.
        degree (int): Device number. Default value is None.
    """
    def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
        super().__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        # Sensitivity tensor of shape (1,) fed as the extra gradient input.
        sens_value = (np.ones((1,)) * sens).astype(np.float32)
        self.sens = Tensor(sens_value)
        self.reduce_flag = reduce_flag
        if reduce_flag:
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
        params = self.weights
        loss = self.network(x, img_shape, gt_bboxe, gt_label, gt_num)
        grad_fn = self.grad(self.network, params)
        grads = grad_fn(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
        if self.reduce_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
|
|
@ -0,0 +1,225 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""coco eval for fasterrcnn"""
|
||||
import json
|
||||
import numpy as np
|
||||
import mmcv
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
# Zero-valued metric table; every key refers to the same 0-d array object.
_init_value = np.array(0.0)
summary_init = {
    metric_name: _init_value
    for metric_name in (
        'Precision/mAP',
        'Precision/mAP@.50IOU',
        'Precision/mAP@.75IOU',
        'Precision/mAP (small)',
        'Precision/mAP (medium)',
        'Precision/mAP (large)',
        'Recall/AR@1',
        'Recall/AR@10',
        'Recall/AR@100',
        'Recall/AR@100 (small)',
        'Recall/AR@100 (medium)',
        'Recall/AR@100 (large)',
    )
}
|
||||
|
||||
|
||||
def _run_coco_evaluator(coco, coco_dets, iou_type, res_type, max_dets, img_ids):
    """Build a COCOeval restricted to ``img_ids``, run it and return it."""
    evaluator = COCOeval(coco, coco_dets, iou_type)
    if res_type == 'proposal':
        evaluator.params.useCats = 0
        evaluator.params.maxDets = list(max_dets)
    evaluator.params.imgIds = img_ids
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()
    return evaluator


def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
    """coco eval for fasterrcnn

    Args:
        result_files (dict): Maps a result type to the path of its ``.json``
            result file; the ``'bbox'`` entry is always read first.
        result_types: Iterable of result types to evaluate.
        coco: A ``COCO`` object or the path of an annotation file.
        max_dets: ``maxDets`` used for ``'proposal'`` results.
        single_result (bool): When true, every detected image is additionally
            evaluated on its own, and the final evaluation covers the
            detected images instead of all ground-truth images.

    Returns:
        dict: The twelve COCO summary metrics of the last evaluated result
        type, or ``summary_init`` when the bbox result file is empty.
    """
    # Close the result file deterministically instead of leaking the handle.
    with open(result_files['bbox']) as bbox_file:
        anns = json.load(bbox_file)
    if not anns:
        return summary_init

    if mmcv.is_str(coco):
        coco = COCO(coco)
    assert isinstance(coco, COCO)

    for res_type in result_types:
        result_file = result_files[res_type]
        assert result_file.endswith('.json')

        coco_dets = coco.loadRes(result_file)
        gt_img_ids = coco.getImgIds()
        det_img_ids = coco_dets.getImgIds()
        iou_type = 'bbox' if res_type == 'proposal' else res_type

        tgt_ids = gt_img_ids if not single_result else det_img_ids

        if single_result:
            # Per-image pass; each evaluator prints its own summary.
            res_dict = dict()
            for id_i in tgt_ids:
                single_eval = _run_coco_evaluator(coco, coco_dets, iou_type, res_type, max_dets, [id_i])
                res_dict.update({coco.imgs[id_i]['file_name']: single_eval.stats[1]})

        cocoEval = _run_coco_evaluator(coco, coco_dets, iou_type, res_type, max_dets, tgt_ids)

        summary_metrics = {
            'Precision/mAP': cocoEval.stats[0],
            'Precision/mAP@.50IOU': cocoEval.stats[1],
            'Precision/mAP@.75IOU': cocoEval.stats[2],
            'Precision/mAP (small)': cocoEval.stats[3],
            'Precision/mAP (medium)': cocoEval.stats[4],
            'Precision/mAP (large)': cocoEval.stats[5],
            'Recall/AR@1': cocoEval.stats[6],
            'Recall/AR@10': cocoEval.stats[7],
            'Recall/AR@100': cocoEval.stats[8],
            'Recall/AR@100 (small)': cocoEval.stats[9],
            'Recall/AR@100 (medium)': cocoEval.stats[10],
            'Recall/AR@100 (large)': cocoEval.stats[11],
        }

    return summary_metrics
|
||||
|
||||
|
||||
def xyxy2xywh(bbox):
    """Convert ``[x1, y1, x2, y2, ...]`` into ``[x1, y1, width, height]``.

    Width and height are computed inclusively (``x2 - x1 + 1``). Only the
    first four entries of ``bbox`` are used.
    """
    coords = bbox.tolist()
    left, top, right, bottom = coords[0], coords[1], coords[2], coords[3]
    width = right - left + 1
    height = bottom - top + 1
    return [left, top, width, height]
|
||||
|
||||
def bbox2result_1image(bboxes, labels, num_classes):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (Tensor): shape (n, 5)
        labels (Tensor): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list(ndarray): bbox results of each class; entry ``i`` holds the rows
        of ``bboxes`` whose label equals ``i``, for ``num_classes - 1`` classes.
    """
    foreground_classes = range(num_classes - 1)
    if bboxes.shape[0] != 0:
        return [bboxes[labels == class_idx, :] for class_idx in foreground_classes]
    return [np.zeros((0, 5), dtype=np.float32) for _ in foreground_classes]
|
||||
|
||||
def proposal2json(dataset, results):
    """Convert per-image proposal arrays into COCO-style json records.

    Args:
        dataset: object exposing ``getImgIds()``. When it additionally
            exposes ``get_dataset_size()``, at most
            ``get_dataset_size() * 2`` images are converted.
        results: sequence aligned with ``dataset.getImgIds()``; each item is
            an array whose rows are read as ``[x1, y1, x2, y2, score]``.

    Returns:
        list(dict): one record per proposal with the keys ``image_id``,
        ``bbox`` (in ``[x, y, w, h]`` form), ``score`` and ``category_id``
        (always 1).
    """
    img_ids = dataset.getImgIds()

    # Never index past the end of either sequence.
    num_images = min(len(img_ids), len(results))

    # ``get_dataset_size`` is not part of the interface that provides
    # ``getImgIds``; honour it only when the object really has it instead of
    # failing with AttributeError.
    size_fn = getattr(dataset, 'get_dataset_size', None)
    if callable(size_fn):
        num_images = min(num_images, size_fn() * 2)

    json_results = []
    for idx in range(num_images):
        img_id = img_ids[idx]
        for bbox in results[idx]:
            json_results.append({
                'image_id': img_id,
                'bbox': xyxy2xywh(bbox),
                'score': float(bbox[4]),
                'category_id': 1,
            })
    return json_results
|
||||
|
||||
def det2json(dataset, results):
    """Convert per-image, per-class detections into COCO-style json records.

    Args:
        dataset: object exposing ``getCatIds()`` and ``getImgIds()``.
        results: sequence aligned with ``dataset.getImgIds()``; each item is
            a per-class sequence of arrays whose rows are read as
            ``[x1, y1, x2, y2, score]``.

    Returns:
        list(dict): one record per detection with the keys ``image_id``,
        ``bbox``, ``score`` and ``category_id``.
    """
    cat_ids = dataset.getCatIds()
    img_ids = dataset.getImgIds()
    records = []
    # zip stops at the shorter input, so images without a result are skipped.
    for img_id, per_class in zip(img_ids, results):
        for label, bboxes in enumerate(per_class):
            for row in range(bboxes.shape[0]):
                box = bboxes[row]
                records.append({
                    'image_id': img_id,
                    'bbox': xyxy2xywh(box),
                    'score': float(box[4]),
                    'category_id': cat_ids[label],
                })
    return records
|
||||
|
||||
def segm2json(dataset, results):
    """Convert detection + segmentation results into COCO-style json records.

    Args:
        dataset: object supporting ``len()`` and exposing ``img_ids`` and
            ``cat_ids`` sequences.
        results: sequence with one ``(det, seg)`` pair per image. ``det`` is
            a per-class sequence of arrays whose rows are read as
            ``[x1, y1, x2, y2, score]``. ``seg`` is either a per-class
            sequence of RLE dicts, or a 2-item container
            ``(per-class RLE dicts, per-class mask scores)``.

    Returns:
        tuple(list, list): ``(bbox_json_results, segm_json_results)``.

    Note:
        The ``counts`` entry of each RLE dict is decoded to ``str`` in place,
        i.e. ``results`` is modified.
    """
    bbox_json_results = []
    segm_json_results = []
    for idx in range(len(dataset)):
        img_id = dataset.img_ids[idx]
        det, seg = results[idx]
        for label, bboxes in enumerate(det):
            # bbox results
            for i in range(bboxes.shape[0]):
                bbox_json_results.append({
                    'image_id': img_id,
                    'bbox': xyxy2xywh(bboxes[i]),
                    'score': float(bboxes[i][4]),
                    'category_id': dataset.cat_ids[label],
                })

            # segm results
            # NOTE(review): a plain per-class list covering exactly two
            # classes also has length 2 and is taken for the
            # (segms, scores) form - confirm callers never pass that.
            if len(seg) == 2:
                segms = seg[0][label]
                mask_score = seg[1][label]
            else:
                segms = seg[label]
                mask_score = [bbox[4] for bbox in bboxes]
            for i in range(bboxes.shape[0]):
                rle = segms[i]
                counts = rle['counts']
                # Decode only raw bytes: counts that are already text (for
                # example after a previous call on the same results) have no
                # ``decode`` and used to raise AttributeError here.
                if isinstance(counts, bytes):
                    rle['counts'] = counts.decode()
                segm_json_results.append({
                    'image_id': img_id,
                    'score': float(mask_score[i]),
                    'category_id': dataset.cat_ids[label],
                    'segmentation': rle,
                })
    return bbox_json_results, segm_json_results
|
||||
|
||||
def results2json(dataset, results, out_file):
    """Dump evaluation results to json files and report where they went.

    The type of ``results[0]`` selects the conversion:
    ``list`` -> detections, ``tuple`` -> detections plus segmentations,
    ``numpy.ndarray`` -> proposals.

    Args:
        dataset: dataset object handed to the matching converter.
        results: non-empty sequence of per-image results.
        out_file (str): prefix of the written files, which are named
            ``<out_file>.<kind>.json``.

    Returns:
        dict: maps result kind (``bbox`` / ``proposal`` / ``segm``) to the
        file name. For detection results the ``proposal`` entry names the
        bbox file.

    Raises:
        TypeError: if ``results[0]`` is none of the supported types.
    """
    def _json_path(kind):
        return '{}.{}.json'.format(out_file, kind)

    first = results[0]

    if isinstance(first, list):
        records = det2json(dataset, results)
        bbox_path = _json_path('bbox')
        mmcv.dump(records, bbox_path)
        return {'bbox': bbox_path, 'proposal': bbox_path}

    if isinstance(first, tuple):
        bbox_records, segm_records = segm2json(dataset, results)
        bbox_path = _json_path('bbox')
        segm_path = _json_path('segm')
        mmcv.dump(bbox_records, bbox_path)
        mmcv.dump(segm_records, segm_path)
        return {'bbox': bbox_path, 'proposal': bbox_path, 'segm': segm_path}

    if isinstance(first, np.ndarray):
        records = proposal2json(dataset, results)
        proposal_path = _json_path('proposal')
        mmcv.dump(records, proposal_path)
        return {'proposal': proposal_path}

    raise TypeError('invalid type of results')
|
|
@ -0,0 +1,216 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""train FasterRcnn and get checkpoint files."""
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor, Parameter
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore.nn import SGD
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
||||
from src.dataset import create_fasterrcnn_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.model_utils.config import config
|
||||
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
|
||||
|
||||
# Global random seed.
set_seed(1)

# Settings read once from the project configuration.
device_target = config.device_target
server_mode = config.server_mode
ms_role = config.ms_role
worker_num = config.worker_num
server_num = config.server_num
scheduler_ip = config.scheduler_ip
scheduler_port = config.scheduler_port
fl_server_port = config.fl_server_port
start_fl_job_threshold = config.start_fl_job_threshold
start_fl_job_time_window = config.start_fl_job_time_window
update_model_ratio = config.update_model_ratio
update_model_time_window = config.update_model_time_window
fl_name = config.fl_name
fl_iteration_num = config.fl_iteration_num
client_epoch_num = config.client_epoch_num
client_batch_size = config.client_batch_size
client_learning_rate = config.client_learning_rate
worker_step_num_per_iteration = config.worker_step_num_per_iteration
scheduler_manage_port = config.scheduler_manage_port
config_file_path = config.config_file_path
encrypt_type = config.encrypt_type

user_id = config.user_id

# Keyword arguments forwarded to context.set_fl_context below.
ctx = dict(
    enable_fl=True,
    server_mode=server_mode,
    ms_role=ms_role,
    worker_num=worker_num,
    server_num=server_num,
    scheduler_ip=scheduler_ip,
    scheduler_port=scheduler_port,
    fl_server_port=fl_server_port,
    start_fl_job_threshold=start_fl_job_threshold,
    start_fl_job_time_window=start_fl_job_time_window,
    update_model_ratio=update_model_ratio,
    update_model_time_window=update_model_time_window,
    fl_name=fl_name,
    fl_iteration_num=fl_iteration_num,
    client_epoch_num=client_epoch_num,
    client_batch_size=client_batch_size,
    client_learning_rate=client_learning_rate,
    worker_step_num_per_iteration=worker_step_num_per_iteration,
    scheduler_manage_port=scheduler_manage_port,
    config_file_path=config_file_path,
    encrypt_type=encrypt_type,
)

context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_fl_context(**ctx)

# This script always runs as a single device.
rank = 0
device_num = 1

# Name of this client's sub-directory (used for its data and checkpoints).
user = "mindrecord_{}".format(user_id)
|
||||
|
||||
def train_fasterrcnn_():
    """Create the training dataset from this client's mindrecord files.

    The mindrecord file is looked up at
    ``<config.dataset_path>/<user>/FasterRcnn.mindrecord``. The function does
    not create it; it blocks, polling every 5 seconds, until the companion
    ``<mindrecord_file>.db`` file exists.

    Returns:
        tuple: ``(dataset_size, dataset)`` where ``dataset`` is the object
        returned by ``create_fasterrcnn_dataset`` and ``dataset_size`` is its
        ``get_dataset_size()``.
    """
    print("Start create dataset!", flush=True)

    # The mindrecord files are expected under <dataset_path>/<user>/ and are
    # named FasterRcnn.mindrecord, FasterRcnn.mindrecord.db, ...
    prefix = "FasterRcnn.mindrecord"
    mindrecord_dir = config.dataset_path
    mindrecord_file = os.path.join(mindrecord_dir, user, prefix)
    print("CHECKING MINDRECORD FILES ...", mindrecord_file, flush=True)

    if rank == 0 and not os.path.exists(mindrecord_file):
        # Report what is actually missing; the previous message blamed
        # image_dir/anno_path, which are never checked here.
        print("Mindrecord file {} does not exist, waiting for it to be generated.".format(mindrecord_file),
              flush=True)

    # Wait until the mindrecord index file shows up.
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)

    print("CHECKING MINDRECORD FILES DONE!", flush=True)

    # The dataset is created from the first mindrecord file.
    dataset = create_fasterrcnn_dataset(config, mindrecord_file, batch_size=config.batch_size,
                                        device_num=device_num, rank_id=rank,
                                        num_parallel_workers=config.num_parallel_workers,
                                        python_multiprocessing=config.python_multiprocessing)

    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!", flush=True)

    return dataset_size, dataset
|
||||
|
||||
class StartFLJob(nn.Cell):
    """Cell that wraps the ``StartFLJob`` operator.

    Args:
        data_size: value handed to ``P.StartFLJob`` when the operator is
            created.
    """

    def __init__(self, data_size):
        super().__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        """Run the wrapped operator and return its result."""
        return self.start_fl_job()
|
||||
|
||||
class UpdateAndGetModel(nn.Cell):
    """Cell that runs ``UpdateModel`` followed by ``GetModel`` on the weights.

    Args:
        weights: parameters passed to both operators.
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights
        self.update_model = P.UpdateModel()
        self.get_model = P.GetModel()

    def construct(self):
        """Call ``UpdateModel`` first, then return the ``GetModel`` result."""
        self.update_model(self.weights)
        return self.get_model(self.weights)
|
||||
|
||||
def _load_pretrained_backbone(net, load_path):
    """Load the backbone weights stored in checkpoint ``load_path`` into ``net``."""
    param_dict = load_checkpoint(load_path)

    # Substrings of checkpoint keys and the names they are replaced with.
    key_mapping = {
        'down_sample_layer.1.beta': 'bn_down_sample.beta',
        'down_sample_layer.1.gamma': 'bn_down_sample.gamma',
        'down_sample_layer.0.weight': 'conv_down_sample.weight',
        'down_sample_layer.1.moving_mean': 'bn_down_sample.moving_mean',
        'down_sample_layer.1.moving_variance': 'bn_down_sample.moving_variance',
    }
    known_prefixes = ('backbone', 'end_point', 'global_step', 'learning_rate', 'moments', 'momentum')

    # Iterate over a snapshot of the keys because entries are renamed in place.
    for oldkey in list(param_dict.keys()):
        if not oldkey.startswith(known_prefixes):
            # Keys without a known prefix are moved under 'backbone.'.
            newkey = 'backbone.' + oldkey
            param_dict[newkey] = param_dict.pop(oldkey)
            oldkey = newkey
        for old_part, new_part in key_mapping.items():
            if old_part in oldkey:
                newkey = oldkey.replace(old_part, new_part)
                param_dict[newkey] = param_dict.pop(oldkey)
                break

    # Only backbone parameters are loaded.
    for item in list(param_dict.keys()):
        if not item.startswith('backbone'):
            param_dict.pop(item)

    # Re-create every parameter as float32.
    for key, value in param_dict.items():
        tensor = value.asnumpy().astype(np.float32)
        param_dict[key] = Parameter(tensor, key)
    load_param_into_net(net, param_dict)


def train():
    """Train Faster R-CNN for ``fl_iteration_num`` federated iterations.

    Each iteration trains for ``config.client_epoch_num`` epochs and saves a
    checkpoint under ``ckpt/<user>/``. When the federated role is
    ``MS_WORKER`` the training is preceded by ``StartFLJob`` and followed by
    ``UpdateAndGetModel``.
    """
    dataset_size, dataset = train_fasterrcnn_()
    net = Faster_Rcnn_Resnet(config=config)
    net = net.set_train()

    load_path = config.pre_trained
    if load_path != "":
        _load_pretrained_backbone(net, load_path)

    loss = LossNet()
    lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)
    opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
              weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    net_with_loss = WithLossCell(net, loss)
    if config.run_distribute:
        net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
                               mean=True, degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]

    model = Model(net)
    ckpt_path1 = os.path.join("ckpt", user)

    # exist_ok: a previous run may already have created the directory, which
    # used to abort the script with FileExistsError.
    os.makedirs(ckpt_path1, exist_ok=True)
    print("====================", config.client_epoch_num, fl_iteration_num, flush=True)
    for iter_num in range(fl_iteration_num):
        if context.get_fl_context("ms_role") == "MS_WORKER":
            start_fl_job = StartFLJob(dataset_size * config.batch_size)
            start_fl_job()
        model.train(config.client_epoch_num, dataset, callbacks=cb)
        if context.get_fl_context("ms_role") == "MS_WORKER":
            update_and_get_model = UpdateAndGetModel(opt.parameters)
            update_and_get_model()
        ckpt_name = user + "-fast-rcnn-" + str(iter_num) + "epoch.ckpt"
        ckpt_path = os.path.join(ckpt_path1, ckpt_name)
        save_checkpoint(net, ckpt_path)
|
||||
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    train()
|
Loading…
Reference in New Issue