forked from mindspore-Ecosystem/mindspore
!6731 Amend deeplabv3 hub config
Merge pull request !6731 from jiangzhenguang/amend_deeplabv3_hub_config
This commit is contained in:
commit
ebb8700875
|
@ -17,11 +17,11 @@ from src.nets import net_factory
|
|||
|
||||
def create_network(name, *args, **kwargs):
|
||||
freeze_bn = True
|
||||
num_classes = 21
|
||||
num_classes = kwargs["num_classes"]
|
||||
if name == 'deeplab_v3_s16':
|
||||
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn)
|
||||
return deeplab_v3_s16_network(*args, **kwargs)
|
||||
return deeplab_v3_s16_network
|
||||
if name == 'deeplab_v3_s8':
|
||||
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
|
||||
return deeplab_v3_s8_network(*args, **kwargs)
|
||||
return deeplab_v3_s8_network
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
Loading…
Reference in New Issue