!10978 modified two parameters to one parameter in yolov3_darknet53 network

From: @shuzigood
Reviewed-by: @wuxuejian,@linqingke
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-01-06 09:56:38 +08:00 committed by Gitee
commit cfda3a49b6
4 changed files with 13 additions and 13 deletions

View File

@ -24,11 +24,9 @@ import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from src.yolo import YOLOV3DarkNet53
from src.logger import get_logger
@ -297,7 +295,6 @@ def test():
# init detection engine
detection = DetectionEngine(args)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
args.logger.info('Start inference....')
for i, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
image = data["image"]
@ -305,7 +302,7 @@ def test():
image_shape = data["image_shape"]
image_id = data["img_id"]
prediction = network(image, input_shape)
prediction = network(image)
output_big, output_me, output_small = prediction
output_big = output_big.asnumpy()
output_me = output_me.asnumpy()
@ -324,7 +321,7 @@ def test():
eval_result = detection.get_eval_result()
cost_time = time.time() - start_time
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
args.logger.info('\n=============coco eval result=========\n' + eval_result)
args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))

View File

@ -47,6 +47,5 @@ if __name__ == "__main__":
shape = [args.batch_size, 3] + config.test_img_shape
input_data = Tensor(np.zeros(shape), ms.float32)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)
export(network, input_data, file_name=args.file_name, file_format=args.file_format)

View File

@ -365,6 +365,7 @@ class YOLOV3DarkNet53(nn.Cell):
def __init__(self, is_training):
super(YOLOV3DarkNet53, self).__init__()
self.config = ConfigYOLOV3DarkNet53()
self.tenser_to_array = P.TupleToArray()
# YOLOv3 network
self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers,
@ -379,7 +380,9 @@ class YOLOV3DarkNet53(nn.Cell):
self.detect_2 = DetectionBlock('m', is_training=is_training)
self.detect_3 = DetectionBlock('s', is_training=is_training)
def construct(self, x, input_shape):
def construct(self, x):
input_shape = F.shape(x)[2:4]
input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
output_big = self.detect_1(big_object_output, input_shape)
output_me = self.detect_2(medium_object_output, input_shape)
@ -394,12 +397,15 @@ class YoloWithLossCell(nn.Cell):
super(YoloWithLossCell, self).__init__()
self.yolo_network = network
self.config = ConfigYOLOV3DarkNet53()
self.tenser_to_array = P.TupleToArray()
self.loss_big = YoloLossBlock('l', self.config)
self.loss_me = YoloLossBlock('m', self.config)
self.loss_small = YoloLossBlock('s', self.config)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
yolo_out = self.yolo_network(x, input_shape)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
input_shape = F.shape(x)[2:4]
input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)
yolo_out = self.yolo_network(x)
loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)

View File

@ -26,7 +26,6 @@ from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
@ -254,9 +253,8 @@ def train():
batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])
input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
batch_gt_box2, input_shape)
batch_gt_box2)
loss_meter.update(loss.asnumpy())
if args.rank_save_ckpt_flag: