forked from mindspore-Ecosystem/mindspore
!21736 Add argument `feature_only` to the backbone of yolov3_resnet18
Merge pull request !21736 from chenhaozhe/modify-yolov3-resnet18-backbone
This commit is contained in:
commit
f92db4fa09
|
@ -156,13 +156,17 @@ class ResNet(nn.Cell):
|
|||
in_channels,
|
||||
out_channels,
|
||||
strides=None,
|
||||
num_classes=80):
|
||||
num_classes=None,
|
||||
feature_only=True):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of "
|
||||
"layer_num, inchannel, outchannel list must be 4!")
|
||||
|
||||
self.feature_only = feature_only
|
||||
if num_classes is None:
|
||||
self.feature_only = True
|
||||
self.conv1 = _conv2d(3, 64, 7, stride=2)
|
||||
self.bn1 = _fused_bn(64)
|
||||
self.relu = P.ReLU()
|
||||
|
@ -240,7 +244,7 @@ class ResNet(nn.Cell):
|
|||
c5 = self.layer4(c4)
|
||||
|
||||
out = c5
|
||||
if self.num_classes:
|
||||
if self.feature_only:
|
||||
out = self.reduce_mean(c5, (2, 3))
|
||||
out = self.squeeze(out)
|
||||
out = self.end_point(out)
|
||||
|
@ -266,7 +270,8 @@ def resnet18(class_num=10):
|
|||
[64, 64, 128, 256],
|
||||
[64, 128, 256, 512],
|
||||
[1, 2, 2, 2],
|
||||
num_classes=class_num)
|
||||
num_classes=class_num,
|
||||
feature_only=False)
|
||||
|
||||
|
||||
class YoloBlock(nn.Cell):
|
||||
|
@ -586,7 +591,8 @@ class yolov3_resnet18(nn.Cell):
|
|||
self.config.backbone_input_shape,
|
||||
self.config.backbone_shape,
|
||||
self.config.backbone_stride,
|
||||
num_classes=None),
|
||||
num_classes=None,
|
||||
feature_only=True),
|
||||
backbone_shape=self.config.backbone_shape,
|
||||
out_channel=self.config.out_channel)
|
||||
|
||||
|
|
Loading…
Reference in New Issue