!22654 add fasterrcnn demo for cross-silo federated

Merge pull request !22654 from zhangqi/0831
This commit is contained in:
i-robot 2021-09-01 09:19:54 +00:00 committed by Gitee
commit 90e807f066
34 changed files with 5051 additions and 0 deletions

View File

@ -0,0 +1,7 @@
{
"recovery": {
"storage_type": 1,
"storage_file_path": "recovery.json"
}
}

View File

@ -0,0 +1,242 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: CPU
enable_profiling: False
# ==============================================================================
# config
img_width: 1280
img_height: 768
keep_ratio: True
flip_ratio: 0.5
expand_ratio: 1.0
# anchor
feature_shapes:
- [192, 320]
- [96, 160]
- [48, 80]
- [24, 40]
- [12, 20]
anchor_scales: [8]
anchor_ratios: [0.5, 1.0, 2.0]
anchor_strides: [4, 8, 16, 32, 64]
num_anchors: 3
# resnet
resnet_block: [3, 4, 6, 3]
resnet_in_channels: [64, 256, 512, 1024]
resnet_out_channels: [256, 512, 1024, 2048]
# fpn
fpn_in_channels: [256, 512, 1024, 2048]
fpn_out_channels: 256
fpn_num_outs: 5
# rpn
rpn_in_channels: 256
rpn_feat_channels: 256
rpn_loss_cls_weight: 1.0
rpn_loss_reg_weight: 1.0
rpn_cls_out_channels: 1
rpn_target_means: [0., 0., 0., 0.]
rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
# bbox_assign_sampler
neg_iou_thr: 0.3
pos_iou_thr: 0.7
min_pos_iou: 0.3
num_bboxes: 245520
num_gts: 128
num_expected_neg: 256
num_expected_pos: 128
# proposal
activate_num_classes: 2
use_sigmoid_cls: True
# roi_align
roi_layer: {type: 'RoIAlign', out_size: 7, sample_num: 2}
roi_align_out_channels: 256
roi_align_featmap_strides: [4, 8, 16, 32]
roi_align_finest_scale: 56
roi_sample_num: 640
# bbox_assign_sampler_stage2
neg_iou_thr_stage2: 0.5
pos_iou_thr_stage2: 0.5
min_pos_iou_stage2: 0.5
num_bboxes_stage2: 2000
num_expected_pos_stage2: 128
num_expected_neg_stage2: 512
num_expected_total_stage2: 512
# rcnn
rcnn_num_layers: 2
rcnn_in_channels: 256
rcnn_fc_out_channels: 1024
rcnn_loss_cls_weight: 1
rcnn_loss_reg_weight: 1
rcnn_target_means: [0., 0., 0., 0.]
rcnn_target_stds: [0.1, 0.1, 0.2, 0.2]
# train proposal
rpn_proposal_nms_across_levels: False
rpn_proposal_nms_pre: 2000
rpn_proposal_nms_post: 2000
rpn_proposal_max_num: 2000
rpn_proposal_nms_thr: 0.7
rpn_proposal_min_bbox_size: 0
# test proposal
rpn_nms_across_levels: False
rpn_nms_pre: 1000
rpn_nms_post: 1000
rpn_max_num: 1000
rpn_nms_thr: 0.7
rpn_min_bbox_min_size: 0
test_score_thr: 0.05
test_iou_thr: 0.5
test_max_per_img: 100
test_batch_size: 2
rpn_head_use_sigmoid: True
rpn_head_weight: 1.0
# LR
base_lr: 0.01
warmup_step: 500
warmup_ratio: 0.0625
sgd_step: [8, 11]
sgd_momentum: 0.9
# train
batch_size: 4
loss_scale: 256
momentum: 0.91
weight_decay: 0.00001
epoch_size: 20
save_checkpoint: True
save_checkpoint_epochs: 5
keep_checkpoint_max: 20
save_checkpoint_path: "./"
# fl
server_mode: "FEDERATED_LEARNING"
ms_role: "MS_WORKER"
worker_num: 1
server_num: 1
scheduler_ip: "10.113.216.40"
scheduler_port: 8113
fl_server_port: 6666
start_fl_job_threshold: 1
start_fl_job_time_window: 3000
update_model_ratio: 1.0
update_model_time_window: 3000
fl_name: "Lenet"
fl_iteration_num: 25
client_epoch_num: 20
client_batch_size: 4
client_learning_rate: 0.01
worker_step_num_per_iteration: 65
scheduler_manage_port: 11202
config_file_path: ""
encrypt_type: "NOT_ENCRYPT"
dataset_path: ""
user_id: 0
# Number of threads used to process the dataset in parallel
num_parallel_workers: 8
# Parallelize Python operations with multiple worker processes
python_multiprocessing: True
mindrecord_dir: "./datasets/coco_split/split_1000/mindrecord_9"
coco_root: "./datasets/coco2017/coco2017"
train_data_type: "train2017"
val_data_type: "val2017"
#instance_set: "annotations/instances_{}.json"
instance_set: "./datasets/coco_split/split_1000/train_9.json"
coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
num_classes: 81
# train.py FasterRcnn training
run_distribute: False
dataset: "coco"
#pre_trained: "./cache/train/fasterrcnn/faster_rcnn-12_7393.ckpt"
pre_trained: "./resnet50_backbone.ckpt"
device_id: 0
device_num: 1
rank_id: 0
image_dir: ''
anno_path: ''
backbone: 'resnet_v1_50'
# eval.py FasterRcnn evaluation
ann_file: './cache/data/annotations/instances_val2017.json'
checkpoint_path: "./cache/train/fasterrcnn/faster_rcnn-12_7393.ckpt"
# export.py fasterrcnn_export
file_name: "faster_rcnn"
file_format: "AIR"
ckpt_file: "./cache/train/fasterrcnn/faster_rcnn-12_7393.ckpt"
# postprocess ("./src/config_50.yaml")
#ann_file: ''
result_path: ''
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
result_dir: "result files path."
label_dir: "image file path."
device_target: "device where the code will be implemented, default is Ascend"
file_name: "output file name."
dataset: "Dataset, either cifar10 or imagenet2012"
parameter_server: 'Run parameter server train'
width: 'input width'
height: 'input height'
enable_profiling: 'Whether enable profiling while training, default: False'
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
run_distribute: 'Run distribute, default is false.'
do_train: 'Do train or not, default is true.'
do_eval: 'Do eval or not, default is false.'
pre_trained: 'Pretrained checkpoint path'
device_id: 'Device id, default is 0.'
device_num: 'Use device nums, default is 1.'
rank_id: 'Rank id, default is 0.'
file_format: 'file format'
ann_file: "Ann file, default is val.json."
checkpoint_path: "Checkpoint file path."
ckpt_file: 'fasterrcnn ckpt file.'
result_path: "result file path."
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152"
---
device_target: ['GPU', 'CPU', 'Ascend']
file_format: ["AIR", "ONNX", "MINDIR"]
dataset_name: ["cifar10", "imagenet2012"]

View File

@ -0,0 +1,30 @@
#copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Finish test_cross_silo_femnist.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)
args, _ = parser.parse_known_args()
scheduler_port = args.scheduler_port
cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"
subprocess.call(['bash', '-c', cmd])

View File

@ -0,0 +1,55 @@
import os
import time
from mindspore.common import set_seed
from src.dataset import data_to_mindrecord_byte_image
from src.model_utils.config import config
set_seed(1)
rank = 0
device_num = 1
def generate_coco_mindrecord():
""" train_fasterrcnn_ """
# It will generate mindrecord file in config.mindrecord_dir,
# and the file name is FasterRcnn.mindrecord0, 1, ... file_num.
prefix = "FasterRcnn.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
print("CHECKING MINDRECORD FILES ...")
if rank == 0 and not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if config.dataset == "coco":
if os.path.isdir(config.coco_root):
if not os.path.exists(config.coco_root):
print("Please make sure config:coco_root is valid.")
raise ValueError(config.coco_root)
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image(config, "coco", True, prefix, 1)
# data_to_mindrecord_byte_image(config, "coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
else:
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
if not os.path.exists(config.image_dir):
print("Please make sure config:image_dir is valid.")
raise ValueError(config.image_dir)
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image(config, "other", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("image_dir or anno_path not exits.")
while not os.path.exists(mindrecord_file + ".db"):
time.sleep(5)
print("CHECKING MINDRECORD FILES DONE!")
if __name__ == '__main__':
generate_coco_mindrecord()

View File

@ -0,0 +1,28 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.model_utils.config import config
if config.backbone in ("resnet_v1.5_50", "resnet_v1_101", "resnet_v1_152"):
from src.FasterRcnn.faster_rcnn_resnet import Faster_Rcnn_Resnet
elif config.backbone == "resnet_v1_50":
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
def create_network(name, *args, **kwargs):
if name == "faster_rcnn":
return Faster_Rcnn_Resnet(config=config)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,3 @@
Cython
pycocotools
mmcv==0.2.14

View File

@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Run test_cross_silo_fasterrcnn.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dataset_path = args.dataset_path
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
cmd_sched += "python ${self_path}/../test_fl_fasterrcnn.py"
cmd_sched += " --device_target=" + device_target
cmd_sched += " --server_mode=" + server_mode
cmd_sched += " --ms_role=MS_SCHED"
cmd_sched += " --worker_num=" + str(worker_num)
cmd_sched += " --server_num=" + str(server_num)
cmd_sched += " --config_file_path=" + str(config_file_path)
cmd_sched += " --scheduler_ip=" + scheduler_ip
cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --dataset_path=" + str(dataset_path)
cmd_sched += " --user_id=" + str(0)
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])

View File

@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Run test_cross_silo_fasterrcnn.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dataset_path", type=str, default="")
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
dataset_path = args.dataset_path
if local_server_num == -1:
local_server_num = server_num
assert local_server_num <= server_num, "The local server number should not be bigger than total server number."
for i in range(local_server_num):
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
cmd_server += "python ${self_path}/../test_fl_fasterrcnn.py"
cmd_server += " --device_target=" + device_target
cmd_server += " --server_mode=" + server_mode
cmd_server += " --ms_role=MS_SERVER"
cmd_server += " --worker_num=" + str(worker_num)
cmd_server += " --server_num=" + str(server_num)
cmd_server += " --scheduler_ip=" + scheduler_ip
cmd_server += " --scheduler_port=" + str(scheduler_port)
cmd_server += " --fl_server_port=" + str(fl_server_port + i)
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --config_file_path=" + str(config_file_path)
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --dataset_path=" + str(dataset_path)
cmd_server += " --user_id=" + str(0)
cmd_server += " > server.log 2>&1 &"
import time
time.sleep(0.3)
subprocess.call(['bash', '-c', cmd_server])

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Run test_cross_silo_fasterrcnn.py case")
parser.add_argument("--device_target", type=str, default="GPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path
if local_worker_num == -1:
local_worker_num = worker_num
assert local_worker_num <= worker_num, "The local worker number should not be bigger than total worker number."
for i in range(local_worker_num):
cmd_worker = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_worker += "rm -rf ${execute_path}/worker_" + str(i) + "/ &&"
cmd_worker += "mkdir ${execute_path}/worker_" + str(i) + "/ &&"
cmd_worker += "cd ${execute_path}/worker_" + str(i) + "/ || exit && export GLOG_v=1 &&"
cmd_worker += "export CUDA_VISIBLE_DEVICES=" + str(i+4) + "&&"
cmd_worker += "python ${self_path}/../test_fl_fasterrcnn.py"
cmd_worker += " --device_target=" + device_target
cmd_worker += " --server_mode=" + server_mode
cmd_worker += " --ms_role=MS_WORKER"
cmd_worker += " --worker_num=" + str(worker_num)
cmd_worker += " --server_num=" + str(server_num)
cmd_worker += " --scheduler_ip=" + scheduler_ip
cmd_worker += " --scheduler_port=" + str(scheduler_port)
cmd_worker += " --config_file_path=" + str(config_file_path)
cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
cmd_worker += " --dataset_path=" + str(dataset_path)
cmd_worker += " --user_id=" + str(i)
cmd_worker += " > worker.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_worker])

View File

@ -0,0 +1,32 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn Init."""
from .resnet import ResNetFea, ResidualBlockUsing
from .resnet50v1 import ResidualBlockUsing_V1
from .bbox_assign_sample import BboxAssignSample
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator
__all__ = [
"ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn",
"FeatPyramidNeck", "Proposal", "Rcnn",
"RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing", "ResidualBlockUsing_V1"
]

View File

@ -0,0 +1,84 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
import numpy as np
class AnchorGenerator():
"""Anchor generator for FasterRcnn."""
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
"""Anchor generator init method."""
self.base_size = base_size
self.scales = np.array(scales)
self.ratios = np.array(ratios)
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()
def gen_base_anchors(self):
"""Generate a single anchor."""
w = self.base_size
h = self.base_size
if self.ctr is None:
x_ctr = 0.5 * (w - 1)
y_ctr = 0.5 * (h - 1)
else:
x_ctr, y_ctr = self.ctr
h_ratios = np.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1)
hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1)
else:
ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1)
hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1)
base_anchors = np.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
axis=-1).round()
return base_anchors
def _meshgrid(self, x, y, row_major=True):
"""Generate grid."""
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
yy = np.repeat(y, len(x))
if row_major:
return xx, yy
return yy, xx
def grid_anchors(self, featmap_size, stride=16):
"""Generate anchor list."""
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w) * stride
shift_y = np.arange(0, feat_h) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = shifts.astype(base_anchors.dtype)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.reshape(-1, 4)
return all_anchors

View File

@ -0,0 +1,165 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler definition.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSample(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
self.dtype = np.float32
self.ms_type = mstype.float32
self.batch_size = batch_size
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, self.ms_type)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, self.ms_type)
self.min_pos_iou = Tensor(cfg.min_pos_iou, self.ms_type)
self.zero_thr = Tensor(0.0, self.ms_type)
self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
self.num_expected_pos = cfg.num_expected_pos
self.num_expected_neg = cfg.num_expected_neg
self.add_gt_as_proposals = add_gt_as_proposals
if self.add_gt_as_proposals:
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.assigned_gt_inds = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
self.check_gt_one = Tensor(np.full((self.num_gts, 4), -1, dtype=self.dtype))
self.check_anchor_two = Tensor(np.full((self.num_bboxes, 4), -2, dtype=self.dtype))
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
self.less(max_overlaps_w_gt, self.neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
assigned_gt_inds4 = assigned_gt_inds3
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
num_pos = self.cast(self.logicalnot(valid_pos_index), self.ms_type)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
pos_bboxes_ = self.gatherND(bboxes, pos_index)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
total_index = self.concat((pos_index, neg_index))
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
labels_total, self.cast(label_weights_total, mstype.bool_)

View File

@ -0,0 +1,196 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn tpositive and negative sample screening for Rcnn."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
class BboxAssignSampleForRcnn(nn.Cell):
"""
Bbox assigner and sampler definition.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSampleForRcnn(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSampleForRcnn, self).__init__()
cfg = config
self.dtype = np.float32
self.ms_type = mstype.float32
self.batch_size = batch_size
self.neg_iou_thr = cfg.neg_iou_thr_stage2
self.pos_iou_thr = cfg.pos_iou_thr_stage2
self.min_pos_iou = cfg.min_pos_iou_stage2
self.num_gts = cfg.num_gts
self.num_bboxes = num_bboxes
self.num_expected_pos = cfg.num_expected_pos_stage2
self.num_expected_neg = cfg.num_expected_neg_stage2
self.num_expected_total = cfg.num_expected_total_stage2
self.add_gt_as_proposals = add_gt_as_proposals
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
self.add_gt_as_proposals_valid = Tensor(np.full(self.num_gts, self.add_gt_as_proposals, dtype=np.int32))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
# Check
self.check_gt_one = Tensor(np.full((self.num_gts, 4), -1, dtype=self.dtype))
self.check_anchor_two = Tensor(np.full((self.num_bboxes, 4), -2, dtype=self.dtype))
# Init tensor
self.assigned_gt_inds = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.gt_ignores = Tensor(np.full(self.num_gts, -1, dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=self.dtype))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)
self.scalar_zero = Tensor(0.0, dtype=self.ms_type)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=self.ms_type)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=self.ms_type)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=self.ms_type)
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), \
gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
self.scalar_zero),
self.less(max_overlaps_w_gt,
self.scalar_neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_ac_j = overlaps[j:j+1:1, ::]
temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
pos_mask_j = self.logicaland(temp1, temp2)
assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)
bboxes = self.concat((gt_bboxes_i, bboxes))
label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))
# Get pos index
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), self.ms_type), -1)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
pos_index = self.reshape(pos_index, self.reshape_shape_pos)
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
pos_index = pos_index * valid_pos_index
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
# Get neg index
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
neg_index = self.reshape(neg_index, self.reshape_shape_neg)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
neg_index = neg_index * valid_neg_index
pos_bboxes_ = self.gatherND(bboxes, pos_index)
neg_bboxes_ = self.gatherND(bboxes, neg_index)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
total_mask = self.concat((valid_pos_index, valid_neg_index))
return total_bboxes, total_deltas, total_labels, total_mask

View File

@ -0,0 +1,488 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn based on ResNet."""
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from .resnet import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator
class Faster_Rcnn_Resnet(nn.Cell):
"""
FasterRcnn Network.
Note:
backbone = resnet
Returns:
Tuple, tuple of output tensor.
rpn_loss: Scalar, Total loss of RPN subnet.
rcnn_loss: Scalar, Total loss of RCNN subnet.
rpn_cls_loss: Scalar, Classification loss of RPN subnet.
rpn_reg_loss: Scalar, Regression loss of RPN subnet.
rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.
Examples:
net = Faster_Rcnn_Resnet()
"""
def __init__(self, config):
super(Faster_Rcnn_Resnet, self).__init__()
self.dtype = np.float32
self.ms_type = mstype.float32
self.train_batch_size = config.batch_size
self.num_classes = config.num_classes
self.anchor_scales = config.anchor_scales
self.anchor_ratios = config.anchor_ratios
self.anchor_strides = config.anchor_strides
self.target_means = tuple(config.rcnn_target_means)
self.target_stds = tuple(config.rcnn_target_stds)
# Anchor generator
anchor_base_sizes = None
self.anchor_base_sizes = list(
self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
featmap_sizes = config.feature_shapes
assert len(featmap_sizes) == len(self.anchor_generators)
self.anchor_list = self.get_anchors(featmap_sizes)
# Backbone resnet
self.backbone = ResNetFea(ResidualBlockUsing,
config.resnet_block,
config.resnet_in_channels,
config.resnet_out_channels,
False)
# Fpn
self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
config.fpn_out_channels,
config.fpn_num_outs)
# Rpn and rpn loss
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
self.rpn_with_loss = RPN(config,
self.train_batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
# Proposal
self.proposal_generator = Proposal(config,
self.train_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator.set_train_local(config, True)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
# Assign and sampler stage two
self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
config.num_bboxes_stage2, True)
self.decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width), means=self.target_means, \
stds=self.target_stds)
# Roi
self.roi_init(config)
# Rcnn
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer.out_size * config.roi_layer.out_size,
self.train_batch_size, self.num_classes)
# Op declare
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.concat_1 = P.Concat(axis=1)
self.concat_2 = P.Concat(axis=2)
self.reshape = P.Reshape()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
# Improve speed
self.concat_start = min(self.num_classes - 2, 55)
self.concat_end = (self.num_classes - 1)
# Test mode
self.test_mode_init(config)
# Init tensor
self.init_tensor(config)
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
def roi_init(self, config):
"""
Initialize roi from the config file
Args:
config (file): config file.
roi_layer (dict): Numbers of block in different layers.
roi_align_out_channels (int): Out channel in each layer.
config.roi_align_featmap_strides (list): featmap_strides in each layer.
roi_align_finest_scale (int): finest_scale in roi.
Examples:
self.roi_init(config)
"""
self.roi_align = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
self.train_batch_size,
config.roi_align_finest_scale)
self.roi_align.set_train_local(config, True)
self.roi_align_test = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
1,
config.roi_align_finest_scale)
self.roi_align_test.set_train_local(config, False)
def test_mode_init(self, config):
"""
Initialize test_mode from the config file.
Args:
config (file): config file.
test_batch_size (int): Size of test batch.
rpn_max_num (int): max num of rpn.
test_score_thresh (float): threshold of test score.
test_iou_thr (float): threshold of test iou.
Examples:
self.test_mode_init(config)
"""
self.test_batch_size = config.test_batch_size
self.split = P.Split(axis=0, output_num=self.test_batch_size)
self.split_shape = P.Split(axis=0, output_num=4)
self.split_scores = P.Split(axis=1, output_num=self.num_classes)
self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
self.tile = P.Tile()
self.gather = P.GatherNd()
self.rpn_max_num = config.rpn_max_num
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.dtype))
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
self.ones_mask, self.zeros_mask), axis=1))
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.dtype) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_iou_thr)
self.test_max_per_img = config.test_max_per_img
self.nms_test = P.NMSWithMask(config.test_iou_thr)
self.softmax = P.Softmax(axis=1)
self.logicand = P.LogicalAnd()
self.oneslike = P.OnesLike()
self.test_topk = P.TopK(sorted=True)
self.test_num_proposal = self.test_batch_size * self.rpn_max_num
def init_tensor(self, config):
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=self.dtype) for i in range(self.train_batch_size)]
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=self.dtype) \
for i in range(self.test_batch_size)]
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
"""
construct the FasterRcnn Network.
Args:
img_data: input image data.
img_metas: meta label of img.
gt_bboxes (Tensor): get the value of bboxes.
gt_labels (Tensor): get the value of labels.
gt_valids (Tensor): get the valid part of bboxes.
Returns:
Tuple,tuple of output tensor
"""
x = self.backbone(img_data)
x = self.fpn_ncek(x)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
self.gt_labels_stage1,
gt_valids)
if self.training:
proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
else:
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
gt_labels = self.cast(gt_labels, mstype.int32)
gt_valids = self.cast(gt_valids, mstype.int32)
bboxes_tuple = ()
deltas_tuple = ()
labels_tuple = ()
mask_tuple = ()
if self.training:
for i in range(self.train_batch_size):
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_labels_i = self.cast(gt_labels_i, mstype.uint8)
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
gt_valids_i = self.cast(gt_valids_i, mstype.bool_)
bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
gt_labels_i,
proposal_mask[i],
proposal[i][::, 0:4:1],
gt_valids_i)
bboxes_tuple += (bboxes,)
deltas_tuple += (deltas,)
labels_tuple += (labels,)
mask_tuple += (mask,)
bbox_targets = self.concat(deltas_tuple)
rcnn_labels = self.concat(labels_tuple)
bbox_targets = F.stop_gradient(bbox_targets)
rcnn_labels = F.stop_gradient(rcnn_labels)
rcnn_labels = self.cast(rcnn_labels, mstype.int32)
else:
mask_tuple += proposal_mask
bbox_targets = proposal_mask
rcnn_labels = proposal_mask
for p_i in proposal:
bboxes_tuple += (p_i[::, 0:4:1],)
if self.training:
if self.train_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
else:
if self.test_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
if self.device_type == "Ascend":
bboxes_all = self.cast(bboxes_all, mstype.float16)
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
rois = self.cast(rois, mstype.float32)
rois = F.stop_gradient(rois)
if self.training:
roi_feats = self.roi_align(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
else:
roi_feats = self.roi_align_test(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
roi_feats = self.cast(roi_feats, self.ms_type)
rcnn_masks = self.concat(mask_tuple)
rcnn_masks = F.stop_gradient(rcnn_masks)
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
bbox_targets,
rcnn_labels,
rcnn_mask_squeeze)
output = ()
if self.training:
output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
else:
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)
return output
def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
"""Get the actual detection box."""
scores = self.softmax(cls_logits)
boxes_all = ()
for i in range(self.num_classes):
k = i * 4
reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
out_boxes_i = self.decode(rois, reg_logits_i)
boxes_all += (out_boxes_i,)
img_metas_all = self.split(img_metas)
scores_all = self.split(scores)
mask_all = self.split(self.cast(mask_logits, mstype.int32))
boxes_all_with_batchsize = ()
for i in range(self.test_batch_size):
scale = self.split_shape(self.squeeze(img_metas_all[i]))
scale_h = scale[2]
scale_w = scale[3]
boxes_tuple = ()
for j in range(self.num_classes):
boxes_tmp = self.split(boxes_all[j])
out_boxes_h = boxes_tmp[i] / scale_h
out_boxes_w = boxes_tmp[i] / scale_w
boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
boxes_all_with_batchsize += (boxes_tuple,)
output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)
return output
def multiclass_nms(self, boxes_all, scores_all, mask_all):
"""Multiscale postprocessing."""
all_bboxes = ()
all_labels = ()
all_masks = ()
for i in range(self.test_batch_size):
bboxes = boxes_all[i]
scores = scores_all[i]
masks = self.cast(mask_all[i], mstype.bool_)
res_boxes_tuple = ()
res_labels_tuple = ()
res_masks_tuple = ()
for j in range(self.num_classes - 1):
k = j + 1
_cls_scores = scores[::, k:k + 1:1]
_bboxes = self.squeeze(bboxes[k])
_mask_o = self.reshape(masks, (self.rpn_max_num, 1))
cls_mask = self.greater(_cls_scores, self.test_score_thresh)
_mask = self.logicand(_mask_o, cls_mask)
_reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)
_bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
_cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
__cls_scores = self.squeeze(_cls_scores)
scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
_bboxes_sorted = self.gather(_bboxes, topk_inds)
_mask_sorted = self.gather(_mask, topk_inds)
scores_sorted = self.tile(scores_sorted, (1, 4))
cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))
cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
_index = self.reshape(_index, (self.rpn_max_num, 1))
_mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))
_mask_n = self.gather(_mask_sorted, _index)
_mask_n = self.logicand(_mask_n, _mask_nms)
cls_labels = self.oneslike(_index) * j
res_boxes_tuple += (cls_dets,)
res_labels_tuple += (cls_labels,)
res_masks_tuple += (_mask_n,)
res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
res_masks_start = self.concat(res_masks_tuple[:self.concat_start])
res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])
res_boxes = self.concat((res_boxes_start, res_boxes_end))
res_labels = self.concat((res_labels_start, res_labels_end))
res_masks = self.concat((res_masks_start, res_masks_end))
reshape_size = (self.num_classes - 1) * self.rpn_max_num
res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
res_labels = self.reshape(res_labels, (1, reshape_size, 1))
res_masks = self.reshape(res_masks, (1, reshape_size, 1))
all_bboxes += (res_boxes,)
all_labels += (res_labels,)
all_masks += (res_masks,)
all_bboxes = self.concat(all_bboxes)
all_labels = self.concat(all_labels)
all_masks = self.concat(all_masks)
return all_bboxes, all_labels, all_masks
def get_anchors(self, featmap_sizes):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = ()
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors += (Tensor(anchors.astype(self.dtype)),)
return multi_level_anchors
class FasterRcnn_Infer(nn.Cell):
def __init__(self, config):
super(FasterRcnn_Infer, self).__init__()
self.network = Faster_Rcnn_Resnet(config)
self.network.set_train(False)
def construct(self, img_data, img_metas):
output = self.network(img_data, img_metas, None, None, None)
return output

View File

@ -0,0 +1,488 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn based on ResNet50v1.0."""
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from .resnet import ResNetFea
from .resnet50v1 import ResidualBlockUsing_V1
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator
class Faster_Rcnn_Resnet(nn.Cell):
"""
FasterRcnn Network.
Note:
backbone = resnet
Returns:
Tuple, tuple of output tensor.
rpn_loss: Scalar, Total loss of RPN subnet.
rcnn_loss: Scalar, Total loss of RCNN subnet.
rpn_cls_loss: Scalar, Classification loss of RPN subnet.
rpn_reg_loss: Scalar, Regression loss of RPN subnet.
rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.
Examples:
net = Faster_Rcnn_Resnet()
"""
def __init__(self, config):
super(Faster_Rcnn_Resnet, self).__init__()
self.dtype = np.float32
self.ms_type = mstype.float32
self.train_batch_size = config.batch_size
self.num_classes = config.num_classes
self.anchor_scales = config.anchor_scales
self.anchor_ratios = config.anchor_ratios
self.anchor_strides = config.anchor_strides
self.target_means = tuple(config.rcnn_target_means)
self.target_stds = tuple(config.rcnn_target_stds)
# Anchor generator
anchor_base_sizes = None
self.anchor_base_sizes = list(
self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
featmap_sizes = config.feature_shapes
assert len(featmap_sizes) == len(self.anchor_generators)
self.anchor_list = self.get_anchors(featmap_sizes)
# Backbone resnet
self.backbone = ResNetFea(ResidualBlockUsing_V1,
config.resnet_block,
config.resnet_in_channels,
config.resnet_out_channels,
False)
# Fpn
self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
config.fpn_out_channels,
config.fpn_num_outs)
# Rpn and rpn loss
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
self.rpn_with_loss = RPN(config,
self.train_batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
# Proposal
self.proposal_generator = Proposal(config,
self.train_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator.set_train_local(config, True)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
# Assign and sampler stage two
self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
config.num_bboxes_stage2, True)
self.decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width), means=self.target_means, \
stds=self.target_stds)
# Roi
self.roi_init(config)
# Rcnn
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer.out_size * config.roi_layer.out_size,
self.train_batch_size, self.num_classes)
# Op declare
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.concat_1 = P.Concat(axis=1)
self.concat_2 = P.Concat(axis=2)
self.reshape = P.Reshape()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
# Improve speed
self.concat_start = min(self.num_classes - 2, 55)
self.concat_end = (self.num_classes - 1)
# Test mode
self.test_mode_init(config)
# Init tensor
self.init_tensor(config)
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
def roi_init(self, config):
"""
Initialize roi from the config file
Args:
config (file): config file.
roi_layer (dict): Numbers of block in different layers.
roi_align_out_channels (int): Out channel in each layer.
config.roi_align_featmap_strides (list): featmap_strides in each layer.
roi_align_finest_scale (int): finest_scale in roi.
Examples:
self.roi_init(config)
"""
self.roi_align = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
self.train_batch_size,
config.roi_align_finest_scale)
self.roi_align.set_train_local(config, True)
self.roi_align_test = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
1,
config.roi_align_finest_scale)
self.roi_align_test.set_train_local(config, False)
def test_mode_init(self, config):
"""
Initialize test_mode from the config file.
Args:
config (file): config file.
test_batch_size (int): Size of test batch.
rpn_max_num (int): max num of rpn.
test_score_thresh (float): threshold of test score.
test_iou_thr (float): threshold of test iou.
Examples:
self.test_mode_init(config)
"""
self.test_batch_size = config.test_batch_size
self.split = P.Split(axis=0, output_num=self.test_batch_size)
self.split_shape = P.Split(axis=0, output_num=4)
self.split_scores = P.Split(axis=1, output_num=self.num_classes)
self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
self.tile = P.Tile()
self.gather = P.GatherNd()
self.rpn_max_num = config.rpn_max_num
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.dtype))
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
self.ones_mask, self.zeros_mask), axis=1))
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.dtype) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_iou_thr)
self.test_max_per_img = config.test_max_per_img
self.nms_test = P.NMSWithMask(config.test_iou_thr)
self.softmax = P.Softmax(axis=1)
self.logicand = P.LogicalAnd()
self.oneslike = P.OnesLike()
self.test_topk = P.TopK(sorted=True)
self.test_num_proposal = self.test_batch_size * self.rpn_max_num
def init_tensor(self, config):
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=self.dtype) for i in range(self.train_batch_size)]
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=self.dtype) \
for i in range(self.test_batch_size)]
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
"""
construct the FasterRcnn Network.
Args:
img_data: input image data.
img_metas: meta label of img.
gt_bboxes (Tensor): get the value of bboxes.
gt_labels (Tensor): get the value of labels.
gt_valids (Tensor): get the valid part of bboxes.
Returns:
Tuple,tuple of output tensor
"""
x = self.backbone(img_data)
x = self.fpn_ncek(x)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
self.gt_labels_stage1,
gt_valids)
if self.training:
proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
else:
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
gt_labels = self.cast(gt_labels, mstype.int32)
gt_valids = self.cast(gt_valids, mstype.int32)
bboxes_tuple = ()
deltas_tuple = ()
labels_tuple = ()
mask_tuple = ()
if self.training:
for i in range(self.train_batch_size):
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_labels_i = self.cast(gt_labels_i, mstype.uint8)
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
gt_valids_i = self.cast(gt_valids_i, mstype.bool_)
bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
gt_labels_i,
proposal_mask[i],
proposal[i][::, 0:4:1],
gt_valids_i)
bboxes_tuple += (bboxes,)
deltas_tuple += (deltas,)
labels_tuple += (labels,)
mask_tuple += (mask,)
bbox_targets = self.concat(deltas_tuple)
rcnn_labels = self.concat(labels_tuple)
bbox_targets = F.stop_gradient(bbox_targets)
rcnn_labels = F.stop_gradient(rcnn_labels)
rcnn_labels = self.cast(rcnn_labels, mstype.int32)
else:
mask_tuple += proposal_mask
bbox_targets = proposal_mask
rcnn_labels = proposal_mask
for p_i in proposal:
bboxes_tuple += (p_i[::, 0:4:1],)
if self.training:
if self.train_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
else:
if self.test_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
if self.device_type == "Ascend":
bboxes_all = self.cast(bboxes_all, mstype.float16)
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
rois = self.cast(rois, mstype.float32)
rois = F.stop_gradient(rois)
if self.training:
roi_feats = self.roi_align(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
else:
roi_feats = self.roi_align_test(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
roi_feats = self.cast(roi_feats, self.ms_type)
rcnn_masks = self.concat(mask_tuple)
rcnn_masks = F.stop_gradient(rcnn_masks)
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
bbox_targets,
rcnn_labels,
rcnn_mask_squeeze)
output = ()
if self.training:
output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
else:
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)
return output
def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
"""Get the actual detection box."""
scores = self.softmax(cls_logits)
boxes_all = ()
for i in range(self.num_classes):
k = i * 4
reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
out_boxes_i = self.decode(rois, reg_logits_i)
boxes_all += (out_boxes_i,)
img_metas_all = self.split(img_metas)
scores_all = self.split(scores)
mask_all = self.split(self.cast(mask_logits, mstype.int32))
boxes_all_with_batchsize = ()
for i in range(self.test_batch_size):
scale = self.split_shape(self.squeeze(img_metas_all[i]))
scale_h = scale[2]
scale_w = scale[3]
boxes_tuple = ()
for j in range(self.num_classes):
boxes_tmp = self.split(boxes_all[j])
out_boxes_h = boxes_tmp[i] / scale_h
out_boxes_w = boxes_tmp[i] / scale_w
boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
boxes_all_with_batchsize += (boxes_tuple,)
output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)
return output
def multiclass_nms(self, boxes_all, scores_all, mask_all):
"""Multiscale postprocessing."""
all_bboxes = ()
all_labels = ()
all_masks = ()
for i in range(self.test_batch_size):
bboxes = boxes_all[i]
scores = scores_all[i]
masks = self.cast(mask_all[i], mstype.bool_)
res_boxes_tuple = ()
res_labels_tuple = ()
res_masks_tuple = ()
for j in range(self.num_classes - 1):
k = j + 1
_cls_scores = scores[::, k:k + 1:1]
_bboxes = self.squeeze(bboxes[k])
_mask_o = self.reshape(masks, (self.rpn_max_num, 1))
cls_mask = self.greater(_cls_scores, self.test_score_thresh)
_mask = self.logicand(_mask_o, cls_mask)
_reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)
_bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
_cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
__cls_scores = self.squeeze(_cls_scores)
scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
_bboxes_sorted = self.gather(_bboxes, topk_inds)
_mask_sorted = self.gather(_mask, topk_inds)
scores_sorted = self.tile(scores_sorted, (1, 4))
cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))
cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
_index = self.reshape(_index, (self.rpn_max_num, 1))
_mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))
_mask_n = self.gather(_mask_sorted, _index)
_mask_n = self.logicand(_mask_n, _mask_nms)
cls_labels = self.oneslike(_index) * j
res_boxes_tuple += (cls_dets,)
res_labels_tuple += (cls_labels,)
res_masks_tuple += (_mask_n,)
res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
res_masks_start = self.concat(res_masks_tuple[:self.concat_start])
res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])
res_boxes = self.concat((res_boxes_start, res_boxes_end))
res_labels = self.concat((res_labels_start, res_labels_end))
res_masks = self.concat((res_masks_start, res_masks_end))
reshape_size = (self.num_classes - 1) * self.rpn_max_num
res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
res_labels = self.reshape(res_labels, (1, reshape_size, 1))
res_masks = self.reshape(res_masks, (1, reshape_size, 1))
all_bboxes += (res_boxes,)
all_labels += (res_labels,)
all_masks += (res_masks,)
all_bboxes = self.concat(all_bboxes)
all_labels = self.concat(all_labels)
all_masks = self.concat(all_masks)
return all_bboxes, all_labels, all_masks
def get_anchors(self, featmap_sizes):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = ()
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors += (Tensor(anchors.astype(self.dtype)),)
return multi_level_anchors
class FasterRcnn_Infer(nn.Cell):
def __init__(self, config):
super(FasterRcnn_Infer, self).__init__()
self.network = Faster_Rcnn_Resnet(config)
self.network.set_train(False)
def construct(self, img_data, img_metas):
output = self.network(img_data, img_metas, None, None, None)
return output

View File

@ -0,0 +1,112 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn feature pyramid network."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
def bias_init_zeros(shape):
"""Bias init method."""
return Tensor(np.array(np.zeros(shape).astype(np.float32)))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32).to_tensor()
shape_bias = (out_channels,)
biass = bias_init_zeros(shape_bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)
class FeatPyramidNeck(nn.Cell):
"""
Feature pyramid network cell, usually uses as network neck.
Applies the convolution on multiple, input feature maps
and output feature map with same channel size. if required num of
output larger then num of inputs, add extra maxpooling for further
downsampling;
Args:
in_channels (tuple) - Channel size of input feature maps.
out_channels (int) - Channel size output.
num_outs (int) - Num of output features.
Returns:
Tuple, with tensors of same channel size.
Examples:
neck = FeatPyramidNeck([100,200,300], 50, 4)
input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)),
dtype=np.float32) \
for i, c in enumerate(config.fpn_in_channels))
x = neck(input_data)
"""
def __init__(self,
in_channels,
out_channels,
num_outs):
super(FeatPyramidNeck, self).__init__()
self.num_outs = num_outs
self.in_channels = in_channels
self.fpn_layer = len(self.in_channels)
assert not self.num_outs < len(in_channels)
self.lateral_convs_list_ = []
self.fpn_convs_ = []
for _, channel in enumerate(in_channels):
l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
self.lateral_convs_list_.append(l_conv)
self.fpn_convs_.append(fpn_conv)
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
self.maxpool = P.MaxPool(kernel_size=1, strides=2, pad_mode="same")
def construct(self, inputs):
x = ()
for i in range(self.fpn_layer):
x += (self.lateral_convs_list[i](inputs[i]),)
y = (x[3],)
y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)
z = ()
for i in range(self.fpn_layer - 1, -1, -1):
z = z + (y[i],)
outs = ()
for i in range(self.fpn_layer):
outs = outs + (self.fpn_convs_list[i](z[i]),)
for i in range(self.num_outs - self.fpn_layer):
outs = outs + (self.maxpool(outs[3]),)
return outs

View File

@ -0,0 +1,198 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
class Proposal(nn.Cell):
"""
Proposal subnet.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_classes (int) - Class number.
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
Returns:
Tuple, tuple of output tensor,(proposal, mask).
Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
self.num_levels = cfg.fpn_num_outs
# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.feature_shapes = cfg.feature_shapes
self.transpose_shape = (1, 2, 0)
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \
means=self.target_means, \
stds=self.target_stds)
self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.dtype = np.float32
self.ms_type = mstype.float32
self.multi_10 = Tensor(10.0, self.ms_type)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.nms_across_levels = cfg.rpn_nms_across_levels
self.max_num = cfg.rpn_max_num
for shp in self.feature_shapes:
k_num = min(self.num_pre, (shp[0] * shp[1] * 3))
total_max_topk_input += k_num
self.topK_stage1 += (k_num,)
self.topK_shape += ((k_num, 1),)
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65500.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float32))
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
cls_score_list = ()
bbox_pred_list = ()
for i in range(self.num_levels):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::])
cls_score_list = cls_score_list + (rpn_cls_score_i,)
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,)
proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal boundingbox."""
mlvl_proposals = ()
mlvl_mask = ()
for idx in range(self.num_levels):
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape)
anchors = mlvl_anchors[idx]
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), self.ms_type)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), self.ms_type)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])
topk_inds = self.reshape(topk_inds, self.topK_shape[idx])
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), self.ms_type)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)
proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, self.ms_type)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks

View File

@ -0,0 +1,179 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn Rcnn network."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore import context
class DenseNoTranpose(nn.Cell):
"""Dense method"""
def __init__(self, input_channels, output_channels, weight_init):
super(DenseNoTranpose, self).__init__()
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32))
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
def construct(self, x):
if self.device_type == "Ascend":
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
output = self.bias_add(self.matmul(x, weight), self.bias)
else:
output = self.bias_add(self.matmul(x, self.weight), self.bias)
return output
class Rcnn(nn.Cell):
"""
Rcnn subnet.
Args:
config (dict) - Config.
representation_size (int) - Channels of shared dense.
batch_size (int) - Batchsize.
num_classes (int) - Class number.
target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]).
target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns:
Tuple, tuple of output tensor.
Examples:
Rcnn(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
representation_size,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(Rcnn, self).__init__()
cfg = config
self.dtype = np.float32
self.ms_type = mstype.float32
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(self.dtype))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(self.dtype))
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.train_batch_size = batch_size
self.test_batch_size = cfg.test_batch_size
shape_0 = (self.rcnn_fc_out_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=self.ms_type).to_tensor()
shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=self.ms_type).to_tensor()
self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0)
self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1)
cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1],
dtype=self.ms_type).to_tensor()
reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1],
dtype=self.ms_type).to_tensor()
self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight)
self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight)
self.flatten = P.Flatten()
self.relu = P.ReLU()
self.logicaland = P.LogicalAnd()
self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0)
self.reshape = P.Reshape()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.gather = P.GatherNd()
self.argmax = P.ArgMaxWithValue(axis=1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, self.ms_type)
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(self.dtype))
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size
range_max = np.arange(self.num_bboxes_test).astype(np.int32)
self.range_max = Tensor(range_max)
def construct(self, featuremap, bbox_targets, labels, mask):
x = self.flatten(featuremap)
x = self.relu(self.shared_fc_0(x))
x = self.relu(self.shared_fc_1(x))
x_cls = self.cls_scores(x)
x_reg = self.reg_scores(x)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
out = (loss, loss_cls, loss_reg, loss_print)
else:
out = (x_cls, (x_cls / self.value), x_reg, x_cls)
return out
def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
"""Loss method."""
loss_print = ()
loss_cls, _ = self.loss_cls(cls_score, labels)
weights = self.cast(weights, self.ms_type)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
self.ms_type)
bbox_weights = bbox_weights * self.rmv_first_tensor
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
loss_reg = self.sum_loss(loss_reg, (2,))
loss_reg = loss_reg * bbox_weights
loss_reg = loss_reg / self.sum_loss(weights, (0,))
loss_reg = self.sum_loss(loss_reg, (0, 1))
loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
loss_print += (loss_cls, loss_reg)
return loss, loss_cls, loss_reg, loss_print

View File

@ -0,0 +1,262 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resnet backbone."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
def weight_init_ones(shape):
"""Weight init."""
return Tensor(np.full(shape, 0.01).astype(np.float32))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = weight_init_ones(shape)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
"""Batchnorm2D wrapper."""
dtype = np.float32
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
class ResNetFea(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
weights_update (bool): Weight update flag.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> False)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
weights_update=False):
super(ResNetFea, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
bn_training = False
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
self.relu = P.ReLU()
self.maxpool = P.MaxPool(kernel_size=3, strides=2, pad_mode="SAME")
self.weights_update = weights_update
if not self.weights_update:
self.conv1.weight.requires_grad = False
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=1,
training=bn_training,
weights_update=self.weights_update)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=2,
training=bn_training,
weights_update=True)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=2,
training=bn_training,
weights_update=True)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=2,
training=bn_training,
weights_update=True)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
"""Make block layer."""
layers = []
down_sample = False
if stride != 1 or in_channel != out_channel:
down_sample = True
resblk = block(in_channel,
out_channel,
stride=stride,
down_sample=down_sample,
training=training,
weights_update=weights_update)
layers.append(resblk)
for _ in range(1, layer_num):
resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
layers.append(resblk)
return nn.SequentialCell(layers)
def construct(self, x):
"""
construct the ResNet Network
Args:
x: input feature data.
Returns:
Tensor, output tensor.
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
identity = c2
if not self.weights_update:
identity = F.stop_gradient(c2)
c3 = self.layer2(identity)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
return identity, c3, c4, c5
class ResidualBlockUsing(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channels (int) - Input channel.
out_channels (int) - Output channel.
stride (int) - Stride size for the initial convolutional layer. Default: 1.
down_sample (bool) - If to do the downsample in block. Default: False.
momentum (float) - Momentum for batchnorm layer. Default: 0.1.
training (bool) - Training flag. Default: False.
weights_updata (bool) - Weights update flag. Default: False.
Returns:
Tensor, output tensor.
Examples:
ResidualBlock(3,256,stride=2,down_sample=True)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
down_sample=False,
momentum=0.1,
training=False,
weights_update=False):
super(ResidualBlockUsing, self).__init__()
self.affine = weights_update
out_chls = out_channels // self.expansion
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0)
self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1)
self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0)
self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training)
if training:
self.bn1 = self.bn1.set_train()
self.bn2 = self.bn2.set_train()
self.bn3 = self.bn3.set_train()
if not weights_update:
self.conv1.weight.requires_grad = False
self.conv2.weight.requires_grad = False
self.conv3.weight.requires_grad = False
self.relu = P.ReLU()
self.downsample = down_sample
if self.downsample:
self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine,
use_batch_statistics=training)
if training:
self.bn_down_sample = self.bn_down_sample.set_train()
if not weights_update:
self.conv_down_sample.weight.requires_grad = False
self.add = P.Add()
def construct(self, x):
"""
construct the ResNet V1 residual block
Args:
x: input feature data.
Returns:
Tensor, output tensor.
"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample:
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
out = self.add(out, identity)
out = self.relu(out)
return out

View File

@ -0,0 +1,264 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resnet50v1.0 backbone."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
def weight_init_ones(shape):
"""Weight init."""
return Tensor(np.full(shape, 0.01).astype(np.float32))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = weight_init_ones(shape)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
"""Batchnorm2D wrapper."""
dtype = np.float32
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
class ResNetFea(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
weights_update (bool): Weight update flag.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> False)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
weights_update=False):
super(ResNetFea, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
bn_training = False
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
self.relu = P.ReLU()
self.maxpool = P.MaxPool(kernel_size=3, strides=2, pad_mode="SAME")
self.weights_update = weights_update
if not self.weights_update:
self.conv1.weight.requires_grad = False
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=1,
training=bn_training,
weights_update=self.weights_update)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=2,
training=bn_training,
weights_update=True)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=2,
training=bn_training,
weights_update=True)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=2,
training=bn_training,
weights_update=True)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
"""Make block layer."""
layers = []
down_sample = False
if stride != 1 or in_channel != out_channel:
down_sample = True
resblk = block(in_channel,
out_channel,
stride=stride,
down_sample=down_sample,
training=training,
weights_update=weights_update)
layers.append(resblk)
for _ in range(1, layer_num):
resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
layers.append(resblk)
return nn.SequentialCell(layers)
def construct(self, x):
"""
construct the ResNet Network
Args:
x: input feature data.
Returns:
Tensor, output tensor.
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
identity = c2
if not self.weights_update:
identity = F.stop_gradient(c2)
c3 = self.layer2(identity)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
return identity, c3, c4, c5
class ResidualBlockUsing_V1(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channels (int) - Input channel.
out_channels (int) - Output channel.
stride (int) - Stride size for the initial convolutional layer. Default: 1.
down_sample (bool) - If to do the downsample in block. Default: False.
momentum (float) - Momentum for batchnorm layer. Default: 0.1.
training (bool) - Training flag. Default: False.
weights_updata (bool) - Weights update flag. Default: False.
Returns:
Tensor, output tensor.
Examples:
ResidualBlock(3,256,stride=2,down_sample=True)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
down_sample=False,
momentum=0.1,
training=False,
weights_update=False):
super(ResidualBlockUsing_V1, self).__init__()
self.affine = weights_update
out_chls = out_channels // self.expansion
# self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0)
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=stride, padding=0)
self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
# self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1)
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=1, padding=1)
self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0)
self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training)
if training:
self.bn1 = self.bn1.set_train()
self.bn2 = self.bn2.set_train()
self.bn3 = self.bn3.set_train()
if not weights_update:
self.conv1.weight.requires_grad = False
self.conv2.weight.requires_grad = False
self.conv3.weight.requires_grad = False
self.relu = P.ReLU()
self.downsample = down_sample
if self.downsample:
self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine,
use_batch_statistics=training)
if training:
self.bn_down_sample = self.bn_down_sample.set_train()
if not weights_update:
self.conv_down_sample.weight.requires_grad = False
self.add = P.Add()
def construct(self, x):
"""
construct the ResNet V1 residual block
Args:
x: input feature data.
Returns:
Tensor, output tensor.
"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample:
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
out = self.add(out, identity)
out = self.relu(out)
return out

View File

@ -0,0 +1,179 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn ROIAlign module."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.tensor import Tensor
class ROIAlign(nn.Cell):
"""
Extract RoI features from multiple feature map.
Args:
out_size_h (int) - RoI height.
out_size_w (int) - RoI width.
spatial_scale (int) - RoI spatial scale.
sample_num (int) - RoI sample number.
"""
def __init__(self,
out_size_h,
out_size_w,
spatial_scale,
sample_num=0):
super(ROIAlign, self).__init__()
self.out_size = (out_size_h, out_size_w)
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
self.spatial_scale, self.sample_num)
def construct(self, features, rois):
return self.align_op(features, rois)
def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
self.out_size, self.spatial_scale, self.sample_num)
return format_str
class SingleRoIExtractor(nn.Cell):
"""
Extract RoI features from a single level feature map.
If there are multiple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
config (dict): Config
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (int): Strides of input feature maps.
batch_size (int) Batchsize.
finest_scale (int): Scale threshold of mapping to level 0.
"""
def __init__(self,
config,
roi_layer,
out_channels,
featmap_strides,
batch_size=1,
finest_scale=56):
super(SingleRoIExtractor, self).__init__()
cfg = config
self.train_batch_size = batch_size
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.out_size = config.roi_layer.out_size
self.sample_num = config.roi_layer.sample_num
self.roi_layers = self.build_roi_layers(self.featmap_strides)
self.roi_layers = L.CellList(self.roi_layers)
self.sqrt = P.Sqrt()
self.log = P.Log()
self.finest_scale_ = finest_scale
self.clamp = C.clip_by_value
self.cast = P.Cast()
self.equal = P.Equal()
self.select = P.Select()
_mode_16 = False
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
self.set_train_local(cfg, training=True)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
# Init tensor
self.batch_size = cfg.roi_sample_num if self.training_local else cfg.rpn_max_num
self.batch_size = self.train_batch_size*self.batch_size \
if self.training_local else cfg.test_batch_size*self.batch_size
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
self.finest_scale = Tensor(finest_scale)
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6))
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1))
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
self.out_size, self.out_size)), dtype=self.dtype))
def num_inputs(self):
return len(self.featmap_strides)
def init_weights(self):
pass
def log2(self, value):
return self.log(value) / self.log(self.twos)
def build_roi_layers(self, featmap_strides):
roi_layers = []
for s in featmap_strides:
layer_cls = ROIAlign(self.out_size, self.out_size,
spatial_scale=1 / s,
sample_num=self.sample_num)
roi_layers.append(layer_cls)
return roi_layers
def _c_map_roi_levels(self, rois):
"""Map rois to corresponding feature levels by scales.
- scale < finest_scale * 2: level 0
- finest_scale * 2 <= scale < finest_scale * 4: level 1
- finest_scale * 4 <= scale < finest_scale * 8: level 2
- scale >= finest_scale * 8: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
"""
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)
target_lvls = self.log2(scale / self.finest_scale + self.epslion)
target_lvls = P.Floor()(target_lvls)
target_lvls = self.cast(target_lvls, mstype.int32)
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)
return target_lvls
def construct(self, rois, feat1, feat2, feat3, feat4):
feats = (feat1, feat2, feat3, feat4)
res = self.res_
target_lvls = self._c_map_roi_levels(rois)
for i in range(self.num_levels):
mask = self.equal(target_lvls, P.ScalarToArray()(i))
mask = P.Reshape()(mask, (-1, 1, 1, 1))
roi_feats_t = self.roi_layers[i](feats[i], rois)
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32),\
(1, 256, self.out_size, self.out_size)), mstype.bool_)
res = self.select(mask, roi_feats_t, res)
return res

View File

@ -0,0 +1,318 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for fasterRCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from .bbox_assign_sample import BboxAssignSample
class RpnRegClsBlock(nn.Cell):
"""
Rpn reg cls block for rpn layer
Args:
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
weight_conv (Tensor) - weight init for rpn conv.
bias_conv (Tensor) - bias init for rpn conv.
weight_cls (Tensor) - weight init for rpn cls conv.
bias_cls (Tensor) - bias init for rpn cls conv.
weight_reg (Tensor) - weight init for rpn reg conv.
bias_reg (Tensor) - bias init for rpn reg conv.
Returns:
Tensor, output tensor.
"""
def __init__(self,
in_channels,
feat_channels,
num_anchors,
cls_out_channels,
weight_conv,
bias_conv,
weight_cls,
bias_cls,
weight_reg,
bias_reg):
super(RpnRegClsBlock, self).__init__()
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same',
has_bias=True, weight_init=weight_conv, bias_init=bias_conv)
self.relu = nn.ReLU()
self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_cls, bias_init=bias_cls)
self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_reg, bias_init=bias_reg)
def construct(self, x):
x = self.relu(self.rpn_conv(x))
x1 = self.rpn_cls(x)
x2 = self.rpn_reg(x)
return x1, x2
class RPN(nn.Cell):
"""
ROI proposal network..
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tuple, tuple of output tensor.
Examples:
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
num_anchors=3, cls_out_channels=512)
"""
def __init__(self,
config,
batch_size,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
self.dtype = np.float32
self.ms_type = mstype.float32
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = ()
self.feature_anchor_shape = ()
self.slice_index += (0,)
index = 0
for shape in cfg_rpn.feature_shapes:
self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,)
self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,)
index += 1
self.num_anchors = num_anchors
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 5
self.real_ratio = Tensor(np.ones((1, 1)).astype(self.dtype))
self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
num_anchors, cls_out_channels))
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(self.dtype))
self.trans_shape = (0, 2, 3, 1)
self.reshape_shape_reg = (-1, 4)
self.reshape_shape_cls = (-1,)
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(self.dtype))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(self.dtype))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(self.dtype))
self.num_bboxes = cfg_rpn.num_bboxes
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum()
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(self.dtype))
self.clsloss = Tensor(np.zeros((1,)).astype(self.dtype))
self.regloss = Tensor(np.zeros((1,)).astype(self.dtype))
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
make rpn layer for rpn proposal network
Args:
num_layers (int) - layer num.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
List, list of RpnRegClsBlock cells.
"""
rpn_layer = []
shp_weight_conv = (feat_channels, in_channels, 3, 3)
shp_bias_conv = (feat_channels,)
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=self.ms_type).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=self.ms_type).to_tensor()
shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
shp_bias_cls = (num_anchors * cls_out_channels,)
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=self.ms_type).to_tensor()
bias_cls = initializer(0, shape=shp_bias_cls, dtype=self.ms_type).to_tensor()
shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
shp_bias_reg = (num_anchors * 4,)
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=self.ms_type).to_tensor()
bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor()
for i in range(num_layers):
rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
weight_conv, bias_conv, weight_cls, \
bias_cls, weight_reg, bias_reg)
if self.device_type == "Ascend":
rpn_reg_cls_block.to_float(mstype.float16)
rpn_layer.append(rpn_reg_cls_block)
for i in range(1, num_layers):
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight
rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias
rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias
rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias
return rpn_layer
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
loss_print = ()
rpn_cls_score = ()
rpn_bbox_pred = ()
rpn_cls_score_total = ()
rpn_bbox_pred_total = ()
for i in range(self.num_layers):
x1, x2 = self.rpn_convs_list[i](inputs[i])
rpn_cls_score_total = rpn_cls_score_total + (x1,)
rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,)
x1 = self.transpose(x1, self.trans_shape)
x1 = self.reshape(x1, self.reshape_shape_cls)
x2 = self.transpose(x2, self.trans_shape)
x2 = self.reshape(x2, self.reshape_shape_reg)
rpn_cls_score = rpn_cls_score + (x1,)
rpn_bbox_pred = rpn_bbox_pred + (x2,)
loss = self.loss
clsloss = self.clsloss
regloss = self.regloss
bbox_targets = ()
bbox_weights = ()
labels = ()
label_weights = ()
output = ()
if self.training:
for i in range(self.batch_size):
multi_level_flags = ()
anchor_list_tuple = ()
for j in range(self.num_layers):
res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])),
mstype.int32)
multi_level_flags = multi_level_flags + (res,)
anchor_list_tuple = anchor_list_tuple + (anchor_list[j],)
valid_flag_list = self.concat(multi_level_flags)
anchor_using_list = self.concat(anchor_list_tuple)
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
gt_labels_i,
self.cast(valid_flag_list,
mstype.bool_),
anchor_using_list, gt_valids_i)
bbox_target = self.cast(bbox_target, self.ms_type)
bbox_weight = self.cast(bbox_weight, self.ms_type)
label = self.cast(label, self.ms_type)
label_weight = self.cast(label_weight, self.ms_type)
for j in range(self.num_layers):
begin = self.slice_index[j]
end = self.slice_index[j + 1]
stride = 1
bbox_targets += (bbox_target[begin:end:stride, ::],)
bbox_weights += (bbox_weight[begin:end:stride],)
labels += (label[begin:end:stride],)
label_weights += (label_weight[begin:end:stride],)
for i in range(self.num_layers):
bbox_target_using = ()
bbox_weight_using = ()
label_using = ()
label_weight_using = ()
for j in range(self.batch_size):
bbox_target_using += (bbox_targets[i + (self.num_layers * j)],)
bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],)
label_using += (labels[i + (self.num_layers * j)],)
label_weight_using += (label_weights[i + (self.num_layers * j)],)
bbox_target_with_batchsize = self.concat(bbox_target_using)
bbox_weight_with_batchsize = self.concat(bbox_weight_using)
label_with_batchsize = self.concat(label_using)
label_weight_with_batchsize = self.concat(label_weight_using)
# stop
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
cls_score_i = self.cast(rpn_cls_score[i], self.ms_type)
reg_score_i = self.cast(rpn_bbox_pred[i], self.ms_type)
loss_cls = self.loss_cls(cls_score_i, label_)
loss_cls_item = loss_cls * label_weight_
loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total
loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4))
loss_reg = loss_reg * bbox_weight_
loss_reg_item = self.sum_loss(loss_reg, (1,))
loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total
loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item
loss += loss_total
loss_print += (loss_total, loss_cls_item, loss_reg_item)
clsloss += loss_cls_item
regloss += loss_reg_item
output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
else:
output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)
return output

View File

@ -0,0 +1,62 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
convert resnet pretrain model to faster_rcnn backbone pretrain model
"""
from mindspore.train.serialization import load_checkpoint, save_checkpoint
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from .model_utils.config import config
def load_weights(model_path, use_fp16_weight):
"""
load resnet pretrain checkpoint file.
Args:
model_path (str): resnet pretrain checkpoint file .
use_fp16_weight(bool): whether save weight into float16.
Returns:
parameter list(list): pretrain model weight list.
"""
ms_ckpt = load_checkpoint(model_path)
weights = {}
for msname in ms_ckpt:
if msname.startswith("layer") or msname.startswith("conv1") or msname.startswith("bn"):
param_name = "backbone." + msname
else:
param_name = msname
if "down_sample_layer.0" in param_name:
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
if "down_sample_layer.1" in param_name:
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
weights[param_name] = ms_ckpt[msname].data.asnumpy()
if use_fp16_weight:
dtype = mstype.float16
else:
dtype = mstype.float32
parameter_dict = {}
for name in weights:
parameter_dict[name] = Parameter(Tensor(weights[name], dtype), name=name)
param_list = []
for key, value in parameter_dict.items():
param_list.append({"name": key, "data": value})
return param_list
if __name__ == "__main__":
parameter_list = load_weights(config.ckpt_file, use_fp16_weight=False)
save_checkpoint(parameter_list, "resnet_backbone.ckpt")

View File

@ -0,0 +1,483 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn dataset"""
from __future__ import division
import os
import numpy as np
from numpy import random
import cv2
import mmcv
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
Args:
bboxes1(ndarray): shape (n, 4)
bboxes2(ndarray): shape (k, 4)
mode(str): iou (intersection over union) or iof (intersection
over foreground)
Returns:
ious(ndarray): shape (n, k)
"""
assert mode in ['iou', 'iof']
bboxes1 = bboxes1.astype(np.float32)
bboxes2 = bboxes2.astype(np.float32)
rows = bboxes1.shape[0]
cols = bboxes2.shape[0]
ious = np.zeros((rows, cols), dtype=np.float32)
if rows * cols == 0:
return ious
exchange = False
if bboxes1.shape[0] > bboxes2.shape[0]:
bboxes1, bboxes2 = bboxes2, bboxes1
ious = np.zeros((cols, rows), dtype=np.float32)
exchange = True
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1)
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1)
for i in range(bboxes1.shape[0]):
x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
y_end - y_start + 1, 0)
if mode == 'iou':
union = area1[i] + area2 - overlap
else:
union = area1[i] if not exchange else area2
ious[i, :] = overlap / union
if exchange:
ious = ious.T
return ious
class PhotoMetricDistortion:
"""Photo Metric Distortion"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, img, boxes, labels):
# random brightness
img = img.astype('float32')
if random.randint(2):
delta = random.uniform(-self.brightness_delta,
self.brightness_delta)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(self.saturation_lower,
self.saturation_upper)
# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]
return img, boxes, labels
class Expand:
"""expand image"""
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
if to_rgb:
self.mean = mean[::-1]
else:
self.mean = mean
self.min_ratio, self.max_ratio = ratio_range
def __call__(self, img, boxes, labels):
if random.randint(2):
return img, boxes, labels
h, w, c = img.shape
ratio = random.uniform(self.min_ratio, self.max_ratio)
expand_img = np.full((int(h * ratio), int(w * ratio), c),
self.mean).astype(img.dtype)
left = int(random.uniform(0, w * ratio - w))
top = int(random.uniform(0, h * ratio - h))
expand_img[top:top + h, left:left + w] = img
img = expand_img
boxes += np.tile((left, top), 2)
return img, boxes, labels
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num, config):
"""rescale operation for image"""
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
if img_data.shape[0] > config.img_height:
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_height), return_scale=True)
scale_factor = scale_factor*scale_factor2
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_data.shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_data.shape[0] - 1)
pad_h = config.img_height - img_data.shape[0]
pad_w = config.img_width - img_data.shape[1]
assert ((pad_h >= 0) and (pad_w >= 0))
pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype)
pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data
img_shape = (config.img_height, config.img_width, 1.0)
img_shape = np.asarray(img_shape, dtype=np.float32)
return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num)
def rescale_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, config):
"""rescale operation for image of eval"""
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
if img_data.shape[0] > config.img_height:
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_height), return_scale=True)
scale_factor = scale_factor*scale_factor2
pad_h = config.img_height - img_data.shape[0]
pad_w = config.img_width - img_data.shape[1]
assert ((pad_h >= 0) and (pad_w >= 0))
pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype)
pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data
img_shape = np.append(img_shape, (scale_factor, scale_factor))
img_shape = np.asarray(img_shape, dtype=np.float32)
return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num)
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num, config):
"""resize operation for image"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = (config.img_height, config.img_width, 1.0)
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, config):
"""resize operation for image of eval"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = np.append(img_shape, (h_scale, w_scale))
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, config):
"""impad operation for image"""
img_data = mmcv.impad(img, (config.img_height, config.img_width))
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""imnormalize operation for image"""
img_data = mmcv.imnormalize(img, np.array([123.675, 116.28, 103.53]), np.array([58.395, 57.12, 57.375]), True)
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""flip operation for image"""
img_data = img
img_data = mmcv.imflip(img_data)
flipped = gt_bboxes.copy()
_, w, _ = img_data.shape
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
return (img_data, img_shape, flipped, gt_label, gt_num)
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""transpose operation for image"""
img_data = img.transpose(2, 0, 1).copy()
img_data = img_data.astype(np.float32)
img_shape = img_shape.astype(np.float32)
gt_bboxes = gt_bboxes.astype(np.float32)
gt_label = gt_label.astype(np.int32)
gt_num = gt_num.astype(np.bool)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""photo crop operation for image"""
random_photo = PhotoMetricDistortion()
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""expand operation for image"""
expand = Expand()
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
return (img, img_shape, gt_bboxes, gt_label, gt_num)
def preprocess_fn(image, box, is_training, config):
"""Preprocess function for dataset."""
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
image_shape = image_shape[:2]
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
if config.keep_ratio:
input_data = rescale_column_test(*input_data, config=config)
else:
input_data = resize_column_test(*input_data, config=config)
input_data = imnormalize_column(*input_data)
output_data = transpose_column(*input_data)
return output_data
def _data_aug(image, box, is_training):
"""Data augmentation function."""
image_bgr = image.copy()
image_bgr[:, :, 0] = image[:, :, 2]
image_bgr[:, :, 1] = image[:, :, 1]
image_bgr[:, :, 2] = image[:, :, 0]
image_shape = image_bgr.shape[:2]
gt_box = box[:, :4]
gt_label = box[:, 4]
gt_iscrowd = box[:, 5]
pad_max_number = 128
gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0)
gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1)
gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1)
gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool))).astype(np.int32)
if not is_training:
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)
flip = (np.random.rand() < config.flip_ratio)
expand = (np.random.rand() < config.expand_ratio)
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
if expand:
input_data = expand_column(*input_data)
if config.keep_ratio:
input_data = rescale_column(*input_data, config=config)
else:
input_data = resize_column(*input_data, config=config)
input_data = imnormalize_column(*input_data)
if flip:
input_data = flip_column(*input_data)
output_data = transpose_column(*input_data)
return output_data
return _data_aug(image, box, is_training)
def create_coco_label(is_training, config):
"""Get image path and annotation from COCO."""
from pycocotools.coco import COCO
coco_root = config.coco_root
data_type = config.val_data_type
if is_training:
data_type = config.train_data_type
# Classes need to train or test.
train_cls = config.coco_classes
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i
anno_json = os.path.join(coco_root, config.instance_set.format(data_type))
coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]
image_ids = coco.getImgIds()
image_files = []
image_anno_dict = {}
for img_id in image_ids:
image_info = coco.loadImgs(img_id)
file_name = image_info[0]["file_name"]
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
# image_path = os.path.join(coco_root, data_type, file_name)
image_path = os.path.join(coco_root, file_name)
annos = []
for label in anno:
bbox = label["bbox"]
class_name = classs_dict[label["category_id"]]
if class_name in train_cls:
x1, x2 = bbox[0], bbox[0] + bbox[2]
y1, y2 = bbox[1], bbox[1] + bbox[3]
annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])])
image_files.append(image_path)
if annos:
image_anno_dict[image_path] = np.array(annos)
else:
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
return image_files, image_anno_dict
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path):
image_anno_dict[image_path] = anno_parser(line_split[1:])
image_files.append(image_path)
return image_files, image_anno_dict
def data_to_mindrecord_byte_image(config, dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.mindrecord_dir
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
if dataset == "coco":
image_files, image_anno_dict = create_coco_label(is_training, config=config)
else:
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
fasterrcnn_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 6]},
}
writer.add_schema(fasterrcnn_json, "fasterrcnn_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_fasterrcnn_dataset(config, mindrecord_file, batch_size=2, device_num=1, rank_id=0, is_training=True,
num_parallel_workers=8, python_multiprocessing=False):
"""Create FasterRcnn dataset with MindDataset."""
cv2.setNumThreads(0)
de.config.set_prefetch_size(8)
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=4, shuffle=is_training)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training, config=config))
if is_training:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func, python_multiprocessing=python_multiprocessing,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
else:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
return ds

View File

@ -0,0 +1,40 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for fasterrcnn"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, steps_per_epoch):
"""dynamic learning rate generator"""
base_lr = config.base_lr
total_steps = steps_per_epoch * (config.epoch_size + 1)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr

View File

@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '5')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -0,0 +1,150 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn training network wrapper."""
import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.count += 1
self.loss_sum += float(loss)
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
total_loss = self.loss_sum / self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.loss_sum = 0
class LossNet(nn.Cell):
"""FasterRcnn loss method"""
def construct(self, x1, x2, x3, x4, x5, x6):
return x1 + x2
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
Append an optimizer to the training network after that the construct function
can be called to create the backward graph.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
loss = self.network(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))

View File

@ -0,0 +1,225 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""coco eval for fasterrcnn"""
import json
import numpy as np
import mmcv
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
_init_value = np.array(0.0)
summary_init = {
'Precision/mAP': _init_value,
'Precision/mAP@.50IOU': _init_value,
'Precision/mAP@.75IOU': _init_value,
'Precision/mAP (small)': _init_value,
'Precision/mAP (medium)': _init_value,
'Precision/mAP (large)': _init_value,
'Recall/AR@1': _init_value,
'Recall/AR@10': _init_value,
'Recall/AR@100': _init_value,
'Recall/AR@100 (small)': _init_value,
'Recall/AR@100 (medium)': _init_value,
'Recall/AR@100 (large)': _init_value,
}
def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
"""coco eval for fasterrcnn"""
anns = json.load(open(result_files['bbox']))
if not anns:
return summary_init
if mmcv.is_str(coco):
coco = COCO(coco)
assert isinstance(coco, COCO)
for res_type in result_types:
result_file = result_files[res_type]
assert result_file.endswith('.json')
coco_dets = coco.loadRes(result_file)
gt_img_ids = coco.getImgIds()
det_img_ids = coco_dets.getImgIds()
iou_type = 'bbox' if res_type == 'proposal' else res_type
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
tgt_ids = gt_img_ids if not single_result else det_img_ids
if single_result:
res_dict = dict()
for id_i in tgt_ids:
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
cocoEval.params.imgIds = [id_i]
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]})
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
cocoEval.params.imgIds = tgt_ids
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
summary_metrics = {
'Precision/mAP': cocoEval.stats[0],
'Precision/mAP@.50IOU': cocoEval.stats[1],
'Precision/mAP@.75IOU': cocoEval.stats[2],
'Precision/mAP (small)': cocoEval.stats[3],
'Precision/mAP (medium)': cocoEval.stats[4],
'Precision/mAP (large)': cocoEval.stats[5],
'Recall/AR@1': cocoEval.stats[6],
'Recall/AR@10': cocoEval.stats[7],
'Recall/AR@100': cocoEval.stats[8],
'Recall/AR@100 (small)': cocoEval.stats[9],
'Recall/AR@100 (medium)': cocoEval.stats[10],
'Recall/AR@100 (large)': cocoEval.stats[11],
}
return summary_metrics
def xyxy2xywh(bbox):
_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0] + 1,
_bbox[3] - _bbox[1] + 1,
]
def bbox2result_1image(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (Tensor): shape (n, 5)
labels (Tensor): shape (n, )
num_classes (int): class number, including background class
Returns:
list(ndarray): bbox results of each class
"""
if bboxes.shape[0] == 0:
result = [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)]
else:
result = [bboxes[labels == i, :] for i in range(num_classes - 1)]
return result
def proposal2json(dataset, results):
"""convert proposal to json mode"""
img_ids = dataset.getImgIds()
json_results = []
dataset_len = dataset.get_dataset_size()*2
for idx in range(dataset_len):
img_id = img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results
def det2json(dataset, results):
"""convert det to json mode"""
cat_ids = dataset.getCatIds()
img_ids = dataset.getImgIds()
json_results = []
dataset_len = len(img_ids)
for idx in range(dataset_len):
img_id = img_ids[idx]
if idx == len(results): break
result = results[idx]
for label, result_label in enumerate(result):
bboxes = result_label
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = cat_ids[label]
json_results.append(data)
return json_results
def segm2json(dataset, results):
"""convert segm to json mode"""
bbox_json_results = []
segm_json_results = []
for idx in range(len(dataset)):
img_id = dataset.img_ids[idx]
det, seg = results[idx]
for label, det_label in enumerate(det):
# bbox results
bboxes = det_label
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = dataset.cat_ids[label]
bbox_json_results.append(data)
if len(seg) == 2:
segms = seg[0][label]
mask_score = seg[1][label]
else:
segms = seg[label]
mask_score = [bbox[4] for bbox in bboxes]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['score'] = float(mask_score[i])
data['category_id'] = dataset.cat_ids[label]
segms[i]['counts'] = segms[i]['counts'].decode()
data['segmentation'] = segms[i]
segm_json_results.append(data)
return bbox_json_results, segm_json_results
def results2json(dataset, results, out_file):
"""convert result convert to json mode"""
result_files = dict()
if isinstance(results[0], list):
json_results = det2json(dataset, results)
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
mmcv.dump(json_results, result_files['bbox'])
elif isinstance(results[0], tuple):
json_results = segm2json(dataset, results)
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['segm'] = '{}.{}.json'.format(out_file, 'segm')
mmcv.dump(json_results[0], result_files['bbox'])
mmcv.dump(json_results[1], result_files['segm'])
elif isinstance(results[0], np.ndarray):
json_results = proposal2json(dataset, results)
result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal')
mmcv.dump(json_results, result_files['proposal'])
else:
raise TypeError('invalid type of results')
return result_files

View File

@ -0,0 +1,216 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train FasterRcnn and get checkpoint files."""
import os
import time
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.train.callback import TimeMonitor
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import save_checkpoint
from mindspore.nn import SGD
from mindspore.common import set_seed
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
from src.dataset import create_fasterrcnn_dataset
from src.lr_schedule import dynamic_lr
from src.model_utils.config import config
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
set_seed(1)
device_target = config.device_target
server_mode = config.server_mode
ms_role = config.ms_role
worker_num = config.worker_num
server_num = config.server_num
scheduler_ip = config.scheduler_ip
scheduler_port = config.scheduler_port
fl_server_port = config.fl_server_port
start_fl_job_threshold = config.start_fl_job_threshold
start_fl_job_time_window = config.start_fl_job_time_window
update_model_ratio = config.update_model_ratio
update_model_time_window = config.update_model_time_window
fl_name = config.fl_name
fl_iteration_num = config.fl_iteration_num
client_epoch_num = config.client_epoch_num
client_batch_size = config.client_batch_size
client_learning_rate = config.client_learning_rate
worker_step_num_per_iteration = config.worker_step_num_per_iteration
scheduler_manage_port = config.scheduler_manage_port
config_file_path = config.config_file_path
encrypt_type = config.encrypt_type
user_id = config.user_id
ctx = {
"enable_fl": True,
"server_mode": server_mode,
"ms_role": ms_role,
"worker_num": worker_num,
"server_num": server_num,
"scheduler_ip": scheduler_ip,
"scheduler_port": scheduler_port,
"fl_server_port": fl_server_port,
"start_fl_job_threshold": start_fl_job_threshold,
"start_fl_job_time_window": start_fl_job_time_window,
"update_model_ratio": update_model_ratio,
"update_model_time_window": update_model_time_window,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"worker_step_num_per_iteration": worker_step_num_per_iteration,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path,
"encrypt_type": encrypt_type
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_fl_context(**ctx)
# print(**ctx, flush=True)
# context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=get_device_id())
# context.set_context(enable_graph_kernel=True)
rank = 0
device_num = 1
user = "mindrecord_" + str(user_id)
def train_fasterrcnn_():
""" train_fasterrcnn_ """
print("Start create dataset!", flush=True)
# It will generate mindrecord file in config.mindrecord_dir,
# and the file name is FasterRcnn.mindrecord0, 1, ... file_num.
prefix = "FasterRcnn.mindrecord"
mindrecord_dir = config.dataset_path
mindrecord_file = os.path.join(mindrecord_dir, user, prefix)
print("CHECKING MINDRECORD FILES ...", mindrecord_file, flush=True)
if rank == 0 and not os.path.exists(mindrecord_file):
print("image_dir or anno_path not exits.", flush=True)
while not os.path.exists(mindrecord_file + ".db"):
time.sleep(5)
print("CHECKING MINDRECORD FILES DONE!", flush=True)
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
dataset = create_fasterrcnn_dataset(config, mindrecord_file, batch_size=config.batch_size,
device_num=device_num, rank_id=rank,
num_parallel_workers=config.num_parallel_workers,
python_multiprocessing=config.python_multiprocessing)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!", flush=True)
return dataset_size, dataset
class StartFLJob(nn.Cell):
def __init__(self, data_size):
super(StartFLJob, self).__init__()
self.start_fl_job = P.StartFLJob(data_size)
def construct(self):
return self.start_fl_job()
class UpdateAndGetModel(nn.Cell):
def __init__(self, weights):
super(UpdateAndGetModel, self).__init__()
self.update_model = P.UpdateModel()
self.get_model = P.GetModel()
self.weights = weights
def construct(self):
self.update_model(self.weights)
get_model = self.get_model(self.weights)
return get_model
def train():
""" train_fasterrcnn """
dataset_size, dataset = train_fasterrcnn_()
net = Faster_Rcnn_Resnet(config=config)
net = net.set_train()
load_path = config.pre_trained
# load_path = ""
if load_path != "":
param_dict = load_checkpoint(load_path)
key_mapping = {'down_sample_layer.1.beta': 'bn_down_sample.beta',
'down_sample_layer.1.gamma': 'bn_down_sample.gamma',
'down_sample_layer.0.weight': 'conv_down_sample.weight',
'down_sample_layer.1.moving_mean': 'bn_down_sample.moving_mean',
'down_sample_layer.1.moving_variance': 'bn_down_sample.moving_variance',
}
for oldkey in list(param_dict.keys()):
if not oldkey.startswith(('backbone', 'end_point', 'global_step', 'learning_rate', 'moments', 'momentum')):
data = param_dict.pop(oldkey)
newkey = 'backbone.' + oldkey
param_dict[newkey] = data
oldkey = newkey
for k, v in key_mapping.items():
if k in oldkey:
newkey = oldkey.replace(k, v)
param_dict[newkey] = param_dict.pop(oldkey)
break
for item in list(param_dict.keys()):
if not item.startswith('backbone'):
param_dict.pop(item)
for key, value in param_dict.items():
tensor = value.asnumpy().astype(np.float32)
param_dict[key] = Parameter(tensor, key)
load_param_into_net(net, param_dict)
loss = LossNet()
lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)
opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if config.run_distribute:
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
model = Model(net)
ckpt_path1 = os.path.join("ckpt", user)
os.makedirs(ckpt_path1)
print("====================", config.client_epoch_num, fl_iteration_num, flush=True)
for iter_num in range(fl_iteration_num):
if context.get_fl_context("ms_role") == "MS_WORKER":
start_fl_job = StartFLJob(dataset_size * config.batch_size)
start_fl_job()
model.train(config.client_epoch_num, dataset, callbacks=cb)
if context.get_fl_context("ms_role") == "MS_WORKER":
update_and_get_model = UpdateAndGetModel(opt.parameters)
update_and_get_model()
ckpt_name = user + "-fast-rcnn-" + str(iter_num) + "epoch.ckpt"
ckpt_path = os.path.join(ckpt_path1, ckpt_name)
save_checkpoint(net, ckpt_path)
if __name__ == '__main__':
train()