amend deeplabv3 export.py

This commit is contained in:
jiangzhenguang 2021-02-24 16:10:58 +08:00
parent d4abe53f34
commit 8a7b85ec24
3 changed files with 4 additions and 3 deletions

View File

@ -496,7 +496,7 @@ Note: There OS is output stride, and MS is multiscale.
| Loss Function | Softmax Cross Entropy |
| Outputs | probability |
| Loss | 0.0065883575 |
| Speed | 60 ms/step1pc, s16<br> 480 ms/step8pcs, s16 <br> 244 ms/step (8pcs, s8) |
| Speed | 60 fps1pc, s16<br> 480 fps8pcs, s16 <br> 244 fps (8pcs, s8) |
| Total time | 8pcs: 706 mins |
| Parameters (M) | 58.2 |
| Checkpoint for Fine tuning | 443M (.ckpt file) |

View File

@ -510,7 +510,7 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
| 损失函数 | Softmax交叉熵 |
| 输出 | 概率 |
| 损失 | 0.0065883575 |
| 速度 | 31毫秒/步单卡s8<br> 234毫秒/步8卡s8 |
| 速度 | 31 帧数/秒单卡s8<br> 234 帧数/秒8卡s8 |
| 微调检查点 | 443M .ckpt文件 |
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) |

View File

@ -17,7 +17,7 @@ import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from eval import BuildEvalNetwork
from src.nets import net_factory
parser = argparse.ArgumentParser(description='checkpoint export')
@ -43,6 +43,7 @@ if __name__ == '__main__':
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
else:
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
network = BuildEvalNetwork(network)
param_dict = load_checkpoint(args.ckpt_file)
# load the parameter into net