From 28373bc877f56ffcfafe317ab028437c9cd0176b Mon Sep 17 00:00:00 2001 From: chenhaozhe Date: Thu, 12 Aug 2021 15:58:26 +0800 Subject: [PATCH] add argument to the backbone of yolo_resnet18 --- .../official/cv/yolov3_resnet18/src/yolov3.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/yolov3_resnet18/src/yolov3.py b/model_zoo/official/cv/yolov3_resnet18/src/yolov3.py index 91ac4081e4b..f7ca11b0d7d 100644 --- a/model_zoo/official/cv/yolov3_resnet18/src/yolov3.py +++ b/model_zoo/official/cv/yolov3_resnet18/src/yolov3.py @@ -156,13 +156,17 @@ class ResNet(nn.Cell): in_channels, out_channels, strides=None, - num_classes=80): + num_classes=None, + feature_only=True): super(ResNet, self).__init__() if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: raise ValueError("the length of " "layer_num, inchannel, outchannel list must be 4!") + self.feature_only = feature_only + if num_classes is None: + self.feature_only = True self.conv1 = _conv2d(3, 64, 7, stride=2) self.bn1 = _fused_bn(64) self.relu = P.ReLU() @@ -240,7 +244,7 @@ class ResNet(nn.Cell): c5 = self.layer4(c4) out = c5 - if self.num_classes: + if self.feature_only: out = self.reduce_mean(c5, (2, 3)) out = self.squeeze(out) out = self.end_point(out) @@ -266,7 +270,8 @@ def resnet18(class_num=10): [64, 64, 128, 256], [64, 128, 256, 512], [1, 2, 2, 2], - num_classes=class_num) + num_classes=class_num, + feature_only=False) class YoloBlock(nn.Cell): @@ -586,7 +591,8 @@ class yolov3_resnet18(nn.Cell): self.config.backbone_input_shape, self.config.backbone_shape, self.config.backbone_stride, - num_classes=None), + num_classes=None, + feature_only=True), backbone_shape=self.config.backbone_shape, out_channel=self.config.out_channel)