!12167 fix issue I2BDOU fasterrcnn accuracy decreased
From: @zhouneng2 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
ea3f8d6aba
|
@ -46,7 +46,7 @@ if __name__ == '__main__':
|
|||
|
||||
load_param_into_net(net, param_dict_new)
|
||||
|
||||
img = Tensor(np.zeros([config.test_batch_size, 3, config.img_height, config.img_width]), ms.float16)
|
||||
img_metas = Tensor(np.random.uniform(0.0, 1.0, size=[config.test_batch_size, 4]), ms.float16)
|
||||
img = Tensor(np.zeros([config.test_batch_size, 3, config.img_height, config.img_width]), ms.float32)
|
||||
img_metas = Tensor(np.random.uniform(0.0, 1.0, size=[config.test_batch_size, 4]), ms.float32)
|
||||
|
||||
export(net, img, img_metas, file_name=args.file_name, file_format=args.file_format)
|
||||
|
|
|
@ -44,7 +44,7 @@ def get_eval_result(ann_file, img_path):
|
|||
label_result_file = result_path + file_id + "_1.bin"
|
||||
mask_result_file = result_path + file_id + "_2.bin"
|
||||
|
||||
all_bbox = np.fromfile(bbox_result_file, dtype=np.float16).reshape(80000, 5)
|
||||
all_bbox = np.fromfile(bbox_result_file, dtype=np.float32).reshape(80000, 5)
|
||||
all_label = np.fromfile(label_result_file, dtype=np.int32).reshape(80000, 1)
|
||||
all_mask = np.fromfile(mask_result_file, dtype=np.bool_).reshape(80000, 1)
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.nn as nn
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class BboxAssignSample(nn.Cell):
|
||||
|
@ -46,9 +45,8 @@ class BboxAssignSample(nn.Cell):
|
|||
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
|
||||
super(BboxAssignSample, self).__init__()
|
||||
cfg = config
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
self.dtype = np.float16 if _mode_16 else np.float32
|
||||
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, self.ms_type)
|
||||
|
@ -89,17 +87,16 @@ class BboxAssignSample(nn.Cell):
|
|||
self.tile = P.Tile()
|
||||
self.zeros_like = P.ZerosLike()
|
||||
|
||||
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_inds = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
|
||||
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ignores = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
|
||||
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
|
||||
|
||||
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
|
||||
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
|
||||
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=self.dtype))
|
||||
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=self.dtype))
|
||||
|
||||
self.check_gt_one = Tensor(np.full((self.num_gts, 4), -1, dtype=self.dtype))
|
||||
self.check_anchor_two = Tensor(np.full((self.num_bboxes, 4), -2, dtype=self.dtype))
|
||||
|
||||
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
|
||||
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.nn as nn
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class BboxAssignSampleForRcnn(nn.Cell):
|
||||
|
@ -46,9 +45,8 @@ class BboxAssignSampleForRcnn(nn.Cell):
|
|||
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
|
||||
super(BboxAssignSampleForRcnn, self).__init__()
|
||||
cfg = config
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
self.dtype = np.float16 if _mode_16 else np.float32
|
||||
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
self.batch_size = batch_size
|
||||
self.neg_iou_thr = cfg.neg_iou_thr_stage2
|
||||
self.pos_iou_thr = cfg.pos_iou_thr_stage2
|
||||
|
@ -61,8 +59,7 @@ class BboxAssignSampleForRcnn(nn.Cell):
|
|||
|
||||
self.add_gt_as_proposals = add_gt_as_proposals
|
||||
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
|
||||
self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
|
||||
dtype=np.int32))
|
||||
self.add_gt_as_proposals_valid = Tensor(np.full(self.num_gts, self.add_gt_as_proposals, dtype=np.int32))
|
||||
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.max_gt = P.ArgMaxWithValue(axis=0)
|
||||
|
@ -87,17 +84,17 @@ class BboxAssignSampleForRcnn(nn.Cell):
|
|||
self.tile = P.Tile()
|
||||
|
||||
# Check
|
||||
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=self.dtype))
|
||||
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=self.dtype))
|
||||
self.check_gt_one = Tensor(np.full((self.num_gts, 4), -1, dtype=self.dtype))
|
||||
self.check_anchor_two = Tensor(np.full((self.num_bboxes, 4), -2, dtype=self.dtype))
|
||||
|
||||
# Init tensor
|
||||
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_inds = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
|
||||
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ignores = Tensor(np.full(num_bboxes, -1, dtype=np.int32))
|
||||
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
|
||||
|
||||
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
|
||||
self.gt_ignores = Tensor(np.full(self.num_gts, -1, dtype=np.int32))
|
||||
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
|
||||
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
|
||||
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=self.dtype))
|
||||
|
|
|
@ -20,7 +20,6 @@ from mindspore.ops import operations as P
|
|||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import context
|
||||
from .resnet50 import ResNetFea, ResidualBlockUsing
|
||||
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
|
||||
from .fpn_neck import FeatPyramidNeck
|
||||
|
@ -30,6 +29,7 @@ from .rpn import RPN
|
|||
from .roi_align import SingleRoIExtractor
|
||||
from .anchor_generator import AnchorGenerator
|
||||
|
||||
|
||||
class Faster_Rcnn_Resnet50(nn.Cell):
|
||||
"""
|
||||
FasterRcnn Network.
|
||||
|
@ -51,9 +51,8 @@ class Faster_Rcnn_Resnet50(nn.Cell):
|
|||
"""
|
||||
def __init__(self, config):
|
||||
super(Faster_Rcnn_Resnet50, self).__init__()
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
self.dtype = np.float16 if _mode_16 else np.float32
|
||||
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
self.train_batch_size = config.batch_size
|
||||
self.num_classes = config.num_classes
|
||||
self.anchor_scales = config.anchor_scales
|
||||
|
@ -117,22 +116,8 @@ class Faster_Rcnn_Resnet50(nn.Cell):
|
|||
config.num_bboxes_stage2, True)
|
||||
self.decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width), means=self.target_means, \
|
||||
stds=self.target_stds)
|
||||
|
||||
# Roi
|
||||
self.roi_align = SingleRoIExtractor(config,
|
||||
config.roi_layer,
|
||||
config.roi_align_out_channels,
|
||||
config.roi_align_featmap_strides,
|
||||
self.train_batch_size,
|
||||
config.roi_align_finest_scale)
|
||||
self.roi_align.set_train_local(config, True)
|
||||
self.roi_align_test = SingleRoIExtractor(config,
|
||||
config.roi_layer,
|
||||
config.roi_align_out_channels,
|
||||
config.roi_align_featmap_strides,
|
||||
1,
|
||||
config.roi_align_finest_scale)
|
||||
self.roi_align_test.set_train_local(config, False)
|
||||
self.roi_init(config)
|
||||
|
||||
# Rcnn
|
||||
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'],
|
||||
|
@ -150,7 +135,33 @@ class Faster_Rcnn_Resnet50(nn.Cell):
|
|||
self.greater = P.Greater()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
# Improve speed
|
||||
self.concat_start = min(self.num_classes - 2, 55)
|
||||
self.concat_end = (self.num_classes - 1)
|
||||
|
||||
# Test mode
|
||||
self.test_mode_init(config)
|
||||
|
||||
# Init tensor
|
||||
self.init_tensor(config)
|
||||
|
||||
def roi_init(self, config):
|
||||
self.roi_align = SingleRoIExtractor(config,
|
||||
config.roi_layer,
|
||||
config.roi_align_out_channels,
|
||||
config.roi_align_featmap_strides,
|
||||
self.train_batch_size,
|
||||
config.roi_align_finest_scale)
|
||||
self.roi_align.set_train_local(config, True)
|
||||
self.roi_align_test = SingleRoIExtractor(config,
|
||||
config.roi_layer,
|
||||
config.roi_align_out_channels,
|
||||
config.roi_align_featmap_strides,
|
||||
1,
|
||||
config.roi_align_finest_scale)
|
||||
self.roi_align_test.set_train_local(config, False)
|
||||
|
||||
def test_mode_init(self, config):
|
||||
self.test_batch_size = config.test_batch_size
|
||||
self.split = P.Split(axis=0, output_num=self.test_batch_size)
|
||||
self.split_shape = P.Split(axis=0, output_num=4)
|
||||
|
@ -181,11 +192,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):
|
|||
self.test_topk = P.TopK(sorted=True)
|
||||
self.test_num_proposal = self.test_batch_size * self.rpn_max_num
|
||||
|
||||
# Improve speed
|
||||
self.concat_start = min(self.num_classes - 2, 55)
|
||||
self.concat_end = (self.num_classes - 1)
|
||||
|
||||
# Init tensor
|
||||
def init_tensor(self, config):
|
||||
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
|
||||
dtype=self.dtype) for i in range(self.train_batch_size)]
|
||||
|
||||
|
@ -262,7 +269,6 @@ class Faster_Rcnn_Resnet50(nn.Cell):
|
|||
bboxes_all = bboxes_tuple[0]
|
||||
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
|
||||
|
||||
|
||||
rois = self.cast(rois, mstype.float32)
|
||||
rois = F.stop_gradient(rois)
|
||||
|
||||
|
@ -279,7 +285,6 @@ class Faster_Rcnn_Resnet50(nn.Cell):
|
|||
self.cast(x[2], mstype.float32),
|
||||
self.cast(x[3], mstype.float32))
|
||||
|
||||
|
||||
roi_feats = self.cast(roi_feats, self.ms_type)
|
||||
rcnn_masks = self.concat(mask_tuple)
|
||||
rcnn_masks = F.stop_gradient(rcnn_masks)
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
@ -25,23 +24,20 @@ from mindspore.common.initializer import initializer
|
|||
|
||||
def bias_init_zeros(shape):
|
||||
"""Bias init method."""
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))
|
||||
return Tensor(np.array(np.zeros(shape).astype(np.float32)))
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
|
||||
"""Conv2D wrapper."""
|
||||
shape = (out_channels, in_channels, kernel_size, kernel_size)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
|
||||
else:
|
||||
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32).to_tensor()
|
||||
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32).to_tensor()
|
||||
shape_bias = (out_channels,)
|
||||
biass = bias_init_zeros(shape_bias)
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)
|
||||
|
||||
|
||||
class FeatPyramidNeck(nn.Cell):
|
||||
"""
|
||||
Feature pyramid network cell, usually uses as network neck.
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.nn as nn
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Proposal(nn.Cell):
|
||||
|
@ -103,9 +102,8 @@ class Proposal(nn.Cell):
|
|||
self.tile = P.Tile()
|
||||
self.set_train_local(config, training=True)
|
||||
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
self.dtype = np.float16 if _mode_16 else np.float32
|
||||
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
|
||||
self.multi_10 = Tensor(10.0, self.ms_type)
|
||||
|
||||
|
@ -134,10 +132,7 @@ class Proposal(nn.Cell):
|
|||
self.topKv2 = P.TopK(sorted=True)
|
||||
self.topK_shape_stage2 = (self.max_num, 1)
|
||||
self.min_float_num = -65536.0
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
|
||||
else:
|
||||
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float32))
|
||||
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float32))
|
||||
|
||||
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
|
||||
proposals_tuple = ()
|
||||
|
|
|
@ -28,18 +28,21 @@ class DenseNoTranpose(nn.Cell):
|
|||
"""Dense method"""
|
||||
def __init__(self, input_channels, output_channels, weight_init):
|
||||
super(DenseNoTranpose, self).__init__()
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16))
|
||||
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16))
|
||||
else:
|
||||
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32))
|
||||
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))
|
||||
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32))
|
||||
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.cast = P.Cast()
|
||||
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
|
||||
|
||||
def construct(self, x):
|
||||
output = self.bias_add(self.matmul(x, self.weight), self.bias)
|
||||
if self.device_type == "Ascend":
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias)
|
||||
else:
|
||||
output = self.bias_add(self.matmul(x, self.weight), self.bias)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -72,9 +75,8 @@ class Rcnn(nn.Cell):
|
|||
):
|
||||
super(Rcnn, self).__init__()
|
||||
cfg = config
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
self.dtype = np.float16 if _mode_16 else np.float32
|
||||
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(self.dtype))
|
||||
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(self.dtype))
|
||||
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
|
||||
|
|
|
@ -19,14 +19,11 @@ import mindspore.nn as nn
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import context
|
||||
|
||||
|
||||
def weight_init_ones(shape):
|
||||
"""Weight init."""
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16))
|
||||
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01))
|
||||
return Tensor(np.full(shape, 0.01).astype(np.float32))
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
|
||||
|
@ -40,8 +37,7 @@ def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mod
|
|||
|
||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
|
||||
"""Batchnorm2D wrapper."""
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
dtype = np.float16 if _mode_16 else np.float32
|
||||
dtype = np.float32
|
||||
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
|
||||
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
|
||||
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import initializer
|
||||
from .bbox_assign_sample import BboxAssignSample
|
||||
|
@ -100,9 +100,8 @@ class RPN(nn.Cell):
|
|||
cls_out_channels):
|
||||
super(RPN, self).__init__()
|
||||
cfg_rpn = config
|
||||
_mode_16 = bool(context.get_context("device_target") == "Ascend")
|
||||
self.dtype = np.float16 if _mode_16 else np.float32
|
||||
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
|
||||
self.dtype = np.float32
|
||||
self.ms_type = mstype.float32
|
||||
self.num_bboxes = cfg_rpn.num_bboxes
|
||||
self.slice_index = ()
|
||||
self.feature_anchor_shape = ()
|
||||
|
|
|
@ -21,13 +21,13 @@ import numpy as np
|
|||
from numpy import random
|
||||
|
||||
import mmcv
|
||||
from mindspore import context
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from src.config import config
|
||||
import cv2
|
||||
|
||||
|
||||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
||||
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
|
||||
|
||||
|
@ -73,6 +73,7 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
|||
ious = ious.T
|
||||
return ious
|
||||
|
||||
|
||||
class PhotoMetricDistortion:
|
||||
"""Photo Metric Distortion"""
|
||||
def __init__(self,
|
||||
|
@ -133,6 +134,7 @@ class PhotoMetricDistortion:
|
|||
|
||||
return img, boxes, labels
|
||||
|
||||
|
||||
class Expand:
|
||||
"""expand image"""
|
||||
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
|
||||
|
@ -157,6 +159,7 @@ class Expand:
|
|||
boxes += np.tile((left, top), 2)
|
||||
return img, boxes, labels
|
||||
|
||||
|
||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""rescale operation for image"""
|
||||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
||||
|
@ -172,6 +175,7 @@ def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""resize operation for image"""
|
||||
img_data = img
|
||||
|
@ -189,6 +193,7 @@ def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""resize operation for image of eval"""
|
||||
img_data = img
|
||||
|
@ -206,18 +211,21 @@ def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""impad operation for image"""
|
||||
img_data = mmcv.impad(img, (config.img_height, config.img_width))
|
||||
img_data = img_data.astype(np.float32)
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""imnormalize operation for image"""
|
||||
img_data = mmcv.imnormalize(img, np.array([123.675, 116.28, 103.53]), np.array([58.395, 57.12, 57.375]), True)
|
||||
img_data = img_data.astype(np.float32)
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""flip operation for image"""
|
||||
img_data = img
|
||||
|
@ -230,22 +238,19 @@ def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|||
|
||||
return (img_data, img_shape, flipped, gt_label, gt_num)
|
||||
|
||||
|
||||
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""transpose operation for image"""
|
||||
img_data = img.transpose(2, 0, 1).copy()
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
img_data = img_data.astype(np.float16)
|
||||
img_shape = img_shape.astype(np.float16)
|
||||
gt_bboxes = gt_bboxes.astype(np.float16)
|
||||
else:
|
||||
img_data = img_data.astype(np.float32)
|
||||
img_shape = img_shape.astype(np.float32)
|
||||
gt_bboxes = gt_bboxes.astype(np.float32)
|
||||
img_data = img_data.astype(np.float32)
|
||||
img_shape = img_shape.astype(np.float32)
|
||||
gt_bboxes = gt_bboxes.astype(np.float32)
|
||||
gt_label = gt_label.astype(np.int32)
|
||||
gt_num = gt_num.astype(np.bool)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""photo crop operation for image"""
|
||||
random_photo = PhotoMetricDistortion()
|
||||
|
@ -253,6 +258,7 @@ def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""expand operation for image"""
|
||||
expand = Expand()
|
||||
|
@ -260,6 +266,7 @@ def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|||
|
||||
return (img, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def preprocess_fn(image, box, is_training):
|
||||
"""Preprocess function for dataset."""
|
||||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
|
||||
|
@ -314,6 +321,7 @@ def preprocess_fn(image, box, is_training):
|
|||
|
||||
return _data_aug(image, box, is_training)
|
||||
|
||||
|
||||
def create_coco_label(is_training):
|
||||
"""Get image path and annotation from COCO."""
|
||||
from pycocotools.coco import COCO
|
||||
|
@ -364,6 +372,7 @@ def create_coco_label(is_training):
|
|||
|
||||
return image_files, image_anno_dict
|
||||
|
||||
|
||||
def anno_parser(annos_str):
|
||||
"""Parse annotation from string to list."""
|
||||
annos = []
|
||||
|
@ -372,6 +381,7 @@ def anno_parser(annos_str):
|
|||
annos.append(anno)
|
||||
return annos
|
||||
|
||||
|
||||
def filter_valid_data(image_dir, anno_path):
|
||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||
image_files = []
|
||||
|
@ -393,6 +403,7 @@ def filter_valid_data(image_dir, anno_path):
|
|||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
|
||||
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
|
@ -417,6 +428,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
|
|||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0, is_training=True,
|
||||
num_parallel_workers=8):
|
||||
"""Create FasterRcnn dataset with MindDataset."""
|
||||
|
|
|
@ -20,12 +20,14 @@ import mindspore.nn as nn
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import ParameterTuple, context
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
@ -109,11 +111,13 @@ class LossCallBack(Callback):
|
|||
self.rcnn_cls_loss_sum = 0
|
||||
self.rcnn_reg_loss_sum = 0
|
||||
|
||||
|
||||
class LossNet(nn.Cell):
|
||||
"""FasterRcnn loss method"""
|
||||
def construct(self, x1, x2, x3, x4, x5, x6):
|
||||
return x1 + x2
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to compute loss.
|
||||
|
@ -167,10 +171,7 @@ class TrainOneStepCell(nn.Cell):
|
|||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
|
||||
else:
|
||||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
|
||||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
|
||||
self.reduce_flag = reduce_flag
|
||||
if reduce_flag:
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
|
|
@ -127,10 +127,10 @@ if __name__ == '__main__':
|
|||
for item in list(param_dict.keys()):
|
||||
if not item.startswith('backbone'):
|
||||
param_dict.pop(item)
|
||||
if args_opt.device_target == "GPU":
|
||||
for key, value in param_dict.items():
|
||||
tensor = value.asnumpy().astype(np.float32)
|
||||
param_dict[key] = Parameter(tensor, key)
|
||||
|
||||
for key, value in param_dict.items():
|
||||
tensor = value.asnumpy().astype(np.float32)
|
||||
param_dict[key] = Parameter(tensor, key)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
loss = LossNet()
|
||||
|
@ -156,4 +156,4 @@ if __name__ == '__main__':
|
|||
cb += [ckpoint_cb]
|
||||
|
||||
model = Model(net)
|
||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
|
||||
model.train(config.epoch_size, dataset, callbacks=cb)
|
||||
|
|
Loading…
Reference in New Issue