forked from mindspore-Ecosystem/mindspore
!17957 maskrcnn repair
From: @huchunmei Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
commit
f86d707126
|
@ -52,7 +52,8 @@ def eval_alexnet():
|
|||
network = AlexNet(config.num_classes, phase='test')
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
opt = nn.Momentum(network.trainable_params(), config.learning_rate, config.momentum)
|
||||
ds_eval = create_dataset_cifar10(config, config.data_path, config.batch_size, target=config.device_target)
|
||||
ds_eval = create_dataset_cifar10(config, config.data_path, config.batch_size, status="test", \
|
||||
target=config.device_target)
|
||||
param_dict = load_checkpoint(config.ckpt_path)
|
||||
print("load checkpoint from [{}].".format(config.ckpt_path))
|
||||
load_param_into_net(network, param_dict)
|
||||
|
|
|
@ -28,7 +28,7 @@ def test_maskrcnn_export():
|
|||
export maskrcnn air.
|
||||
"""
|
||||
old_list = ["(config=config)", "(net, param_dict_new)"]
|
||||
new_list = ["(config=config\\n) '''", "(net, param_dict_new)\\n '''"]
|
||||
new_list = ["(config=config)\\n '''", "(net, param_dict_new)\\n '''"]
|
||||
|
||||
cur_path = os.getcwd()
|
||||
model_path = "{}/../../../../model_zoo/official/cv".format(cur_path)
|
||||
|
|
Loading…
Reference in New Issue