forked from OSSInnovation/mindspore
!6356 improve maskrcnn precision and performance
Merge pull request !6356 from gengdongjie/master
This commit is contained in:
commit
74255d9faf
|
@ -26,9 +26,8 @@ tile_op_info = TBERegOp("Tile") \
|
|||
.attr("multiples", "optional", "listInt", "all")\
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class CropAndResize(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
method (str): An optional string that specifies the sampling method for resizing.
|
||||
It can be either "bilinear" or "nearest". Default: "bilinear"
|
||||
It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for standard bilinear
|
||||
interpolation algorithm, while "bilinear_v2" may result in better result in some cases. Default: "bilinear"
|
||||
extrapolation_value (float): An optional float value used extrapolation, if applicable. Default: 0.
|
||||
|
||||
Inputs:
|
||||
|
@ -81,7 +82,7 @@ class CropAndResize(PrimitiveWithInfer):
|
|||
"""init CropAndResize"""
|
||||
self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
|
||||
validator.check_value_type("method", method, [str], self.name)
|
||||
validator.check_string("method", method, ["bilinear", "nearest"], self.name)
|
||||
validator.check_string("method", method, ["bilinear", "nearest", "bilinear_v2"], self.name)
|
||||
self.method = method
|
||||
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
|
||||
self.extrapolation_value = extrapolation_value
|
||||
|
|
|
@ -406,9 +406,14 @@ def create_coco_label(is_training):
|
|||
image_anno_dict = {}
|
||||
masks = {}
|
||||
masks_shape = {}
|
||||
for img_id in image_ids:
|
||||
images_num = len(image_ids)
|
||||
for ind, img_id in enumerate(image_ids):
|
||||
image_info = coco.loadImgs(img_id)
|
||||
file_name = image_info[0]["file_name"]
|
||||
image_path = os.path.join(coco_root, data_type, file_name)
|
||||
if not os.path.isfile(image_path):
|
||||
print("{}/{}: {} is in annotations but not exist".format(ind + 1, images_num, image_path))
|
||||
continue
|
||||
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||
anno = coco.loadAnns(anno_ids)
|
||||
image_path = os.path.join(coco_root, data_type, file_name)
|
||||
|
@ -416,7 +421,8 @@ def create_coco_label(is_training):
|
|||
instance_masks = []
|
||||
image_height = coco.imgs[img_id]["height"]
|
||||
image_width = coco.imgs[img_id]["width"]
|
||||
print("image file name: ", file_name)
|
||||
if (ind + 1) % 10 == 0:
|
||||
print("{}/{}: parsing annotation for image={}".format(ind + 1, images_num, file_name))
|
||||
if not is_training:
|
||||
image_files.append(image_path)
|
||||
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
|
||||
|
@ -478,13 +484,16 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="mask
|
|||
}
|
||||
writer.add_schema(maskrcnn_json, "maskrcnn_json")
|
||||
|
||||
for image_name in image_files:
|
||||
image_files_num = len(image_files)
|
||||
for ind, image_name in enumerate(image_files):
|
||||
with open(image_name, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
||||
mask = masks[image_name]
|
||||
mask_shape = masks_shape[image_name]
|
||||
row = {"image": img, "annotation": annos, "mask": mask, "mask_shape": mask_shape}
|
||||
if (ind + 1) % 10 == 0:
|
||||
print("writing {}/{} into mindrecord".format(ind + 1, image_files_num))
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
|
|
@ -108,7 +108,7 @@ class BboxAssignSampleForRcnn(nn.Cell):
|
|||
self.round = P.Round()
|
||||
self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16)
|
||||
self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2)
|
||||
self.crop_and_resize = P.CropAndResize()
|
||||
self.crop_and_resize = P.CropAndResize(method="bilinear_v2")
|
||||
self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1])
|
||||
self.squeeze_mask_last = P.Squeeze(axis=-1)
|
||||
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i):
|
||||
|
|
|
@ -84,9 +84,10 @@ class FeatPyramidNeck(nn.Cell):
|
|||
self.fpn_convs_.append(fpn_conv)
|
||||
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
|
||||
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
|
||||
self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
|
||||
self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
|
||||
self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
|
||||
self.interpolate1 = P.ResizeBilinear((48, 80))
|
||||
self.interpolate2 = P.ResizeBilinear((96, 160))
|
||||
self.interpolate3 = P.ResizeBilinear((192, 320))
|
||||
self.cast = P.Cast()
|
||||
self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same")
|
||||
|
||||
def construct(self, inputs):
|
||||
|
@ -95,9 +96,9 @@ class FeatPyramidNeck(nn.Cell):
|
|||
x += (self.lateral_convs_list[i](inputs[i]),)
|
||||
|
||||
y = (x[3],)
|
||||
y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
|
||||
y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
|
||||
y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)
|
||||
y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), mstype.float16),)
|
||||
y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), mstype.float16),)
|
||||
y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), mstype.float16),)
|
||||
|
||||
z = ()
|
||||
for i in range(self.fpn_layer - 1, -1, -1):
|
||||
|
|
|
@ -247,7 +247,6 @@ def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_clas
|
|||
else:
|
||||
img_h = np.round(ori_shape[0] * scale_factor[0]).astype(np.int32)
|
||||
img_w = np.round(ori_shape[1] * scale_factor[1]).astype(np.int32)
|
||||
scale_factor = 1.0
|
||||
|
||||
for i in range(bboxes.shape[0]):
|
||||
bbox = (bboxes[i, :] / 1.0).astype(np.int32)
|
||||
|
@ -256,6 +255,10 @@ def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_clas
|
|||
h = max(bbox[3] - bbox[1] + 1, 1)
|
||||
w = min(w, img_w - bbox[0])
|
||||
h = min(h, img_h - bbox[1])
|
||||
if w <= 0 or h <= 0:
|
||||
print("there is invalid proposal bbox, index={} bbox={} w={} h={}".format(i, bbox, w, h))
|
||||
w = max(w, 1)
|
||||
h = max(h, 1)
|
||||
mask_pred_ = mask_pred[i, :, :]
|
||||
im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
|
||||
bbox_mask = mmcv.imresize(mask_pred_, (w, h))
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""train MaskRcnn and get checkpoint files."""
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
|
@ -26,7 +27,7 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMoni
|
|||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn import SGD
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
||||
|
@ -71,7 +72,7 @@ if __name__ == '__main__':
|
|||
prefix = "MaskRcnn.mindrecord"
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if rank == 0 and not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if args_opt.dataset == "coco":
|
||||
|
@ -80,14 +81,16 @@ if __name__ == '__main__':
|
|||
data_to_mindrecord_byte_image("coco", True, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("coco_root not exits.")
|
||||
raise Exception("coco_root not exits.")
|
||||
else:
|
||||
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
|
||||
print("Create Mindrecord.")
|
||||
data_to_mindrecord_byte_image("other", True, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("IMAGE_DIR or ANNO_PATH not exits.")
|
||||
raise Exception("IMAGE_DIR or ANNO_PATH not exits.")
|
||||
while not os.path.exists(mindrecord_file+".db"):
|
||||
time.sleep(5)
|
||||
|
||||
if not args_opt.only_create_dataset:
|
||||
loss_scale = float(config.loss_scale)
|
||||
|
@ -115,8 +118,8 @@ if __name__ == '__main__':
|
|||
loss = LossNet()
|
||||
lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
|
||||
mstype.float32)
|
||||
opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
|
||||
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
||||
opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
|
||||
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
if args_opt.run_distribute:
|
||||
|
|
Loading…
Reference in New Issue