add argument to the backbone of yolo_resnet18

This commit is contained in:
chenhaozhe 2021-08-12 15:58:26 +08:00
parent bff5cbda9c
commit 28373bc877
1 changed files with 10 additions and 4 deletions

View File

@ -156,13 +156,17 @@ class ResNet(nn.Cell):
in_channels,
out_channels,
strides=None,
num_classes=80):
num_classes=None,
feature_only=True):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
self.feature_only = feature_only
if num_classes is None:
self.feature_only = True
self.conv1 = _conv2d(3, 64, 7, stride=2)
self.bn1 = _fused_bn(64)
self.relu = P.ReLU()
@ -240,7 +244,7 @@ class ResNet(nn.Cell):
c5 = self.layer4(c4)
out = c5
if self.num_classes:
if self.feature_only:
out = self.reduce_mean(c5, (2, 3))
out = self.squeeze(out)
out = self.end_point(out)
@ -266,7 +270,8 @@ def resnet18(class_num=10):
[64, 64, 128, 256],
[64, 128, 256, 512],
[1, 2, 2, 2],
num_classes=class_num)
num_classes=class_num,
feature_only=False)
class YoloBlock(nn.Cell):
@ -586,7 +591,8 @@ class yolov3_resnet18(nn.Cell):
self.config.backbone_input_shape,
self.config.backbone_shape,
self.config.backbone_stride,
num_classes=None),
num_classes=None,
feature_only=True),
backbone_shape=self.config.backbone_shape,
out_channel=self.config.out_channel)