forked from OSSInnovation/mindspore
fix mobilenetv2 script error
This commit is contained in:
parent
03ddf421ca
commit
a7e881f312
|
@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context
|
|||
if __name__ == '__main__':
|
||||
args_opt = eval_parse_args()
|
||||
config = set_config(args_opt)
|
||||
backbone_net, head_net, net = define_net(config)
|
||||
backbone_net, head_net, net = define_net(config, args_opt.is_training)
|
||||
|
||||
#load the trained checkpoint file to the net for evaluation
|
||||
if args_opt.head_ckpt:
|
||||
|
|
|
@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
|
|||
for param in network.get_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def define_net(config):
|
||||
def define_net(config, is_training):
|
||||
backbone_net = MobileNetV2Backbone()
|
||||
activation = config.activation if not args.is_training else "None"
|
||||
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
|
||||
net = mobilenet_v2(backbone_net, head_net, activation=activation)
|
||||
activation = config.activation if not is_training else "None"
|
||||
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
|
||||
num_classes=config.num_classes,
|
||||
activation=activation)
|
||||
net = mobilenet_v2(backbone_net, head_net)
|
||||
return backbone_net, head_net, net
|
||||
|
|
|
@ -51,7 +51,7 @@ if __name__ == '__main__':
|
|||
context_device_init(config)
|
||||
|
||||
# define network
|
||||
backbone_net, head_net, net = define_net(config)
|
||||
backbone_net, head_net, net = define_net(config, args_opt.is_training)
|
||||
|
||||
# load the ckpt file to the network for fine tune or incremental leaning
|
||||
if args_opt.pretrain_ckpt:
|
||||
|
|
Loading…
Reference in New Issue