!14379 Solve the problem of sudden loss increases in FasterRCNN

From: @zhouneng2
Reviewed-by: @oacjiewen, @liangchenghui
Signed-off-by: @liangchenghui
mindspore-ci-bot committed 2021-03-30 21:40:00 +08:00, via Gitee
commit 8d42a57093
5 changed files with 24 additions and 6 deletions
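
Note (summary, inferred from the diff below): on Ascend the FasterRCNN network (and each RPN reg/cls block) now runs in float16 via Cell.to_float(mstype.float16), while everything feeding the loss (RPN targets, weights, and the cls/reg scores) is explicitly cast back to float32. The loss-side casts suggest the sudden loss spikes came from computing the RPN loss in float16.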

View File

@@ -19,6 +19,7 @@ import argparse
 import time
 import numpy as np
 from pycocotools.coco import COCO
+import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed, Parameter
@@ -51,7 +52,11 @@ def fasterrcnn_eval(dataset_path, ckpt_path, ann_file):
             tensor = value.asnumpy().astype(np.float32)
             param_dict[key] = Parameter(tensor, key)
     load_param_into_net(net, param_dict)
     net.set_train(False)
+    device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
+    if device_type == "Ascend":
+        net.to_float(mstype.float16)
+
     eval_iter = 0
     total = ds.get_dataset_size()
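
For readers unfamiliar with MindSpore mixed precision: Cell.to_float(mstype.float16) marks a cell (recursively) so its compute runs in float16, with input casts inserted automatically. A minimal sketch with a toy nn.Dense cell, not the FasterRCNN model:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    net = nn.Dense(4, 2)
    net.to_float(mstype.float16)   # run this cell's compute in float16
    x = Tensor(np.ones((1, 4)), mstype.float32)
    out = net(x)                   # input is cast on entry; out is float16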

View File

@@ -16,6 +16,7 @@
 import numpy as np
 import mindspore.nn as nn
+from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
@@ -144,6 +145,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):
         # Init tensor
         self.init_tensor(config)
+        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"

     def roi_init(self, config):
         self.roi_align = SingleRoIExtractor(config,
@@ -267,6 +269,8 @@ class Faster_Rcnn_Resnet50(nn.Cell):
             bboxes_all = self.concat(bboxes_tuple)
         else:
             bboxes_all = bboxes_tuple[0]
+        if self.device_type == "Ascend":
+            bboxes_all = self.cast(bboxes_all, mstype.float16)
         rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
         rois = self.cast(rois, mstype.float32)
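
The added cast exists because P.Concat requires its inputs to share a dtype: with the ROI index tensor presumably held in float16 on Ascend, the float32 proposal boxes must be cast before the two are concatenated. A standalone sketch with toy stand-in tensors (shapes and names are illustrative only):

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor
    from mindspore.ops import operations as P

    cast = P.Cast()
    concat = P.Concat(axis=1)

    roi_index = Tensor(np.zeros((8, 1), np.float16))  # stand-in for roi_align_index_test_tensor
    bboxes_all = Tensor(np.ones((8, 4), np.float32))  # stand-in for the gathered proposals

    bboxes_all = cast(bboxes_all, mstype.float16)     # dtypes must match before Concat
    rois = cast(concat((roi_index, bboxes_all)), mstype.float32)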

View File

@@ -40,7 +40,7 @@ class DenseNoTranpose(nn.Cell):
         if self.device_type == "Ascend":
             x = self.cast(x, mstype.float16)
             weight = self.cast(self.weight, mstype.float16)
-            output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias)
+            output = self.bias_add(self.matmul(x, weight), self.bias)
         else:
             output = self.bias_add(self.matmul(x, self.weight), self.bias)
         return output
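
The changed line drops the intermediate cast back to float32, so MatMul and BiasAdd both stay in float16 on Ascend. A self-contained sketch of the resulting pattern (toy shapes; the bias is shown in float16 so P.BiasAdd sees matching dtypes, which is an assumption of this sketch, not a claim about the real parameter's dtype):

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor
    from mindspore.ops import operations as P

    cast = P.Cast()
    matmul = P.MatMul()
    bias_add = P.BiasAdd()

    x = Tensor(np.ones((2, 3), np.float32))
    weight = Tensor(np.ones((3, 4), np.float32))
    bias = Tensor(np.zeros((4,), np.float16))   # assumed float16 for the dtype match

    x16 = cast(x, mstype.float16)
    w16 = cast(weight, mstype.float16)
    output = bias_add(matmul(x16, w16), bias)   # stays in float16, no fp32 round trip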

View File

@@ -16,7 +16,7 @@
 import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
-from mindspore import Tensor
+from mindspore import context, Tensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.initializer import initializer
@@ -102,6 +102,7 @@ class RPN(nn.Cell):
         cfg_rpn = config
         self.dtype = np.float32
         self.ms_type = mstype.float32
+        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
         self.num_bboxes = cfg_rpn.num_bboxes
         self.slice_index = ()
         self.feature_anchor_shape = ()
@@ -180,9 +181,12 @@ class RPN(nn.Cell):
         bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor()

         for i in range(num_layers):
-            rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
+            rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
                                             weight_conv, bias_conv, weight_cls, \
-                                            bias_cls, weight_reg, bias_reg))
+                                            bias_cls, weight_reg, bias_reg)
+            if self.device_type == "Ascend":
+                rpn_reg_cls_block.to_float(mstype.float16)
+            rpn_layer.append(rpn_reg_cls_block)

         for i in range(1, num_layers):
             rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
@@ -250,6 +254,7 @@ class RPN(nn.Cell):
                                                      mstype.bool_),
                                                      anchor_using_list, gt_valids_i)

+            bbox_target = self.cast(bbox_target, self.ms_type)
             bbox_weight = self.cast(bbox_weight, self.ms_type)
             label = self.cast(label, self.ms_type)
             label_weight = self.cast(label_weight, self.ms_type)
@@ -286,8 +291,8 @@ class RPN(nn.Cell):
             label_ = F.stop_gradient(label_with_batchsize)
             label_weight_ = F.stop_gradient(label_weight_with_batchsize)
-            cls_score_i = rpn_cls_score[i]
-            reg_score_i = rpn_bbox_pred[i]
+            cls_score_i = self.cast(rpn_cls_score[i], self.ms_type)
+            reg_score_i = self.cast(rpn_bbox_pred[i], self.ms_type)

             loss_cls = self.loss_cls(cls_score_i, label_)
             loss_cls_item = loss_cls * label_weight_
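
The RPN changes are the two halves of the same fix: the reg/cls heads run in float16 (via to_float above), but their outputs are cast back to self.ms_type (float32) before entering the loss. Computing a loss in float16 is a classic cause of sudden spikes, since large logits or sums exceed float16's range (max about 65504). A minimal sketch of the "fp16 logits, fp32 loss" pattern, using the same op the RPN uses for classification loss:

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor
    from mindspore.ops import operations as P

    cast = P.Cast()
    loss_cls = P.SigmoidCrossEntropyWithLogits()

    cls_score = Tensor(np.random.randn(8, 1).astype(np.float16))  # float16 head output
    label = Tensor(np.ones((8, 1), np.float32))

    loss = loss_cls(cast(cls_score, mstype.float32), label)       # loss computed in float32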

View File

@@ -152,6 +152,10 @@ if __name__ == '__main__':
             param_dict[key] = Parameter(tensor, key)
     load_param_into_net(net, param_dict)

+    device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
+    if device_type == "Ascend":
+        net.to_float(mstype.float16)
+
     loss = LossNet()
     lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)
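
Taken together, training ends up with the network in float16 and the loss path in float32. A toy end-to-end sketch of that split (hypothetical small net and loss, not the FasterRCNN trainer):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    net = nn.Dense(4, 2)
    net.to_float(mstype.float16)                # forward pass in float16

    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    x = Tensor(np.ones((3, 4)), mstype.float32)
    label = Tensor(np.array([0, 1, 0]), mstype.int32)

    logits = net(x)                                        # float16 logits
    loss = loss_fn(logits.astype(mstype.float32), label)   # loss in float32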