amend deeplabv3 hub config

This commit is contained in:
jzg 2020-09-22 19:51:55 +08:00
parent 9475f9a19a
commit a32c5fbc92
1 changed files with 3 additions and 3 deletions

View File

@ -17,11 +17,11 @@ from src.nets import net_factory
def create_network(name, *args, **kwargs):
freeze_bn = True
num_classes = 21
num_classes = kwargs["num_classes"]
if name == 'deeplab_v3_s16':
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn)
return deeplab_v3_s16_network(*args, **kwargs)
return deeplab_v3_s16_network
if name == 'deeplab_v3_s8':
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
return deeplab_v3_s8_network(*args, **kwargs)
return deeplab_v3_s8_network
raise NotImplementedError(f"{name} is not implemented in the repo")