forked from mindspore-Ecosystem/mindspore
Remove model_zoo from mindspore
This commit is contained in:
parent
2969688382
commit
6c62abd19f
|
@ -1,4 +0,0 @@
|
|||
approvers:
|
||||
- c_34
|
||||
-gengdongjie
|
||||
-zhao_ting_v
|
|
@ -1,126 +1,10 @@
|
|||
# ![MindSpore Logo](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
|
||||
|
||||
## Welcome to the Model Zoo for MindSpore
|
||||
|
||||
In order to facilitate developers to enjoy the benefits of MindSpore framework, we will continue to add typical networks and some of the related pre-trained models. If you have needs for the model zoo, you can file an issue on [gitee](https://gitee.com/mindspore/mindspore/issues) or [MindSpore](https://bbs.huaweicloud.com/forum/forum-1076-1.html), We will consider it in time.
|
||||
|
||||
- SOTA models using the latest MindSpore APIs
|
||||
|
||||
- The best benefits from MindSpore
|
||||
|
||||
- Officially maintained and supported
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Official](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official)
|
||||
- [Computer Vision](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [Image Classification](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [DenseNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/densenet/README.md)
|
||||
- [GoogleNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet/README.md)
|
||||
- [ResNet50[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/README.md)
|
||||
- [ResNet50_Quant](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/resnet50_quant/README.md)
|
||||
- [ResNet101](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/README.md)
|
||||
- [ResNext50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnext/README.md)
|
||||
- [VGG16](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg16/README.md)
|
||||
- [AlexNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet/README.md)
|
||||
- [LeNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/README.md)
|
||||
- [LeNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet_quant/Readme.md)
|
||||
- [MobileNetV2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2/README.md)
|
||||
- [MobileNetV2_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2_quant/Readme.md)
|
||||
- [MobileNetV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv3/Readme.md)
|
||||
- [InceptionV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv3/README.md)
|
||||
- [InceptionV4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4/README.md)
|
||||
- [Object Detection and Segmentation](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [DeepLabV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3/README.md)
|
||||
- [FasterRCNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/faster_rcnn/README.md)
|
||||
- [YoloV3-DarkNet53](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53/README.md)
|
||||
- [YoloV3-ResNet18](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_resnet18/README.md)
|
||||
- [MaskRCNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/maskrcnn/README.md)
|
||||
- [Unet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet/README.md)
|
||||
- [SSD](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd/README.md)
|
||||
- [Warp-CTC](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/warpctc/README.md)
|
||||
- [RetinaFace-ResNet50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/retinaface_resnet50/README.md)
|
||||
- [Keypoint Detection](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [OpenPose](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/openpose/README.md)
|
||||
|
||||
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp)
|
||||
- [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
|
||||
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
|
||||
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
|
||||
- [FastText](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext/README.md)
|
||||
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
|
||||
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
|
||||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
- [CPM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/cpm/README.md)
|
||||
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
|
||||
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
|
||||
- [NAML](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/naml/README.md)
|
||||
- [Wide&Deep[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
|
||||
- [Graph Neural Networks](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
|
||||
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)
|
||||
- [GAT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gat/README.md)
|
||||
- [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn//README.md)
|
||||
|
||||
- [Research](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research)
|
||||
- [Computer Vision](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv)
|
||||
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/Readme.md)
|
||||
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
|
||||
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
|
||||
- [CycleGAN](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/research/cv/CycleGAN/README.md)
|
||||
- [FaceAttribute](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceAttribute/README.md)
|
||||
- [FaceDetection](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceDetection/README.md)
|
||||
- [FaceQualityAssessment](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceQualityAssessment/README.md)
|
||||
- [FaceRecognitionForTracking](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognitionForTracking/README.md)
|
||||
- [FaceRecognition](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognition/README.md)
|
||||
- [CenterNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet/README.md)
|
||||
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
|
||||
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
|
||||
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend)
|
||||
- [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
|
||||
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
|
||||
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
|
||||
- [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
|
||||
- [Wavenet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet/README.md)
|
||||
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
|
||||
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
|
||||
- [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)
|
||||
- [PINNs](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/research/hpc/pinns/README.md)
|
||||
- [Community](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)
|
||||
![MindSpore Logo](https://gitee.com/mindspore/mindspore/raw/master/docs/MindSpore-logo.png "MindSpore logo")
|
||||
|
||||
## Announcements
|
||||
|
||||
| Date | News |
|
||||
| ------------ | ------------------------------------------------------------ |
|
||||
| September 25, 2020 | Support [MindSpore v1.0.0](https://www.mindspore.cn/news/newschildren/en?id=262) |
|
||||
| September 01, 2020 | Support [MindSpore v0.7.0-beta](https://www.mindspore.cn/news/newschildren/en?id=246) |
|
||||
| July 31, 2020 | Support [MindSpore v0.6.0-beta](https://www.mindspore.cn/news/newschildren/en?id=237) |
|
||||
**Modelzoo has been moved to the new repository [models](https://gitee.com/mindspore/models)**.
|
||||
|
||||
## Related Website
|
||||
For more detail about separation of model_zoo, you could refer to the following issues:
|
||||
|
||||
Here is the ModelZoo for MindSpore which support different devices including Ascend, GPU, CPU and mobile.
|
||||
|
||||
If you are looking for exclusive models only for Ascend using different ML platform, you could refer to [Ascend ModelZoo](https://hiascend.com/software/modelzoo) and corresponding [gitee repository](https://gitee.com/ascend/modelzoo)
|
||||
|
||||
Modelzoo will be transferred to a new repo [models](https://gitee.com/mindspore/models).
|
||||
|
||||
## Disclaimers
|
||||
|
||||
Mindspore only provides scripts that downloads and preprocesses public datasets. We do not own these datasets and are not responsible for their quality or maintenance. Please make sure you have permission to use the dataset under the dataset’s license. The models trained on these dataset are for non-commercial research and educational purpose only.
|
||||
|
||||
To dataset owners: we will remove or update all public content upon request if you don’t want your dataset included on Mindspore, or wish to update it in any way. Please contact us through a Github/Gitee issue. Your understanding and contribution to this community is greatly appreciated.
|
||||
|
||||
MindSpore is Apache 2.0 licensed. Please see the LICENSE file.
|
||||
|
||||
## License
|
||||
|
||||
[Apache License 2.0](https://gitee.com/mindspore/mindspore/blob/master/LICENSE)
|
||||
|
||||
## FAQ
|
||||
|
||||
- **Q: How to resolve the lack of memory while using `PYNATIVE_MODE` with errors such as *Failed to alloc memory pool memory*?**
|
||||
|
||||
**A**: `PYNATIVE_MODE` usually requires more memory than `GRAPH_MODE`, especially in training process which have to deal with back propagation. You could try using smaller batch size.
|
||||
|
||||
- **Q: How to resolve the error about the interface not supported, such as `cann not import`?**
|
||||
|
||||
**A**: Please check the version of MindSpore and the branch you fetch the modelzoo scripts. Some model scripits in latest branch will use new interface in the latest version of MindSpore.
|
||||
- https://gitee.com/mindspore/mindspore/issues/I49LM0
|
||||
- https://gitee.com/mindspore/mindspore/issues/I3TZ87
|
||||
|
|
|
@ -1,126 +1,10 @@
|
|||
# ![MindSpore Logo](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
|
||||
|
||||
## 欢迎来到MindSpore ModelZoo
|
||||
|
||||
为了让开发者更好地体验MindSpore框架优势,我们将陆续增加更多的典型网络和相关预训练模型。如果您对ModelZoo有任何需求,请通过[Gitee](https://gitee.com/mindspore/mindspore/issues)或[MindSpore](https://bbs.huaweicloud.com/forum/forum-1076-1.html)与我们联系,我们将及时处理。
|
||||
|
||||
- 使用最新MindSpore API的SOTA模型
|
||||
|
||||
- MindSpore优势
|
||||
|
||||
- 官方维护和支持
|
||||
|
||||
## 目录
|
||||
|
||||
- [官方](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official)
|
||||
- [计算机视觉](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [图像分类](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [DenseNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/densenet/README.md)
|
||||
- [GoogleNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet/README.md)
|
||||
- [ResNet-50[基准]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/README.md)
|
||||
- [ResNet50_Quant](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/resnet50_quant/README.md)
|
||||
- [ResNet-101](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/README.md)
|
||||
- [ResNeXt-50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnext/README_CN.md)
|
||||
- [VGG16](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg16/README.md)
|
||||
- [AlexNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet/README.md)
|
||||
- [LeNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/README.md)
|
||||
- [LeNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet_quant/Readme.md)
|
||||
- [MobileNetV2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2/README.md)
|
||||
- [MobileNetV2_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2_quant/Readme.md)
|
||||
- [MobileNetV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv3/Readme.md)
|
||||
- [InceptionV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv3/README.md)
|
||||
- [InceptionV4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4/README.md)
|
||||
- [目标检测与分割](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [DeepLabV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3/README.md)
|
||||
- [Faster R-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/faster_rcnn/README.md)
|
||||
- [YoloV3-DarkNet53](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53/README.md)
|
||||
- [YoloV3-ResNet18](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_resnet18/README.md)
|
||||
- [Mask R-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/maskrcnn/README.md)
|
||||
- [U-Net](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet/README.md)
|
||||
- [SSD](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd/README.md)
|
||||
- [Warp-CTC](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/warpctc/README.md)
|
||||
- [RetinaFace-ResNet50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/retinaface_resnet50/README.md)
|
||||
- [关键点检测](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [OpenPose](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/openpose/README.md)
|
||||
|
||||
- [自然语言处理](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp)
|
||||
- [BERT[基准]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
|
||||
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
|
||||
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
|
||||
- [FastText](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext/README.md)
|
||||
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
|
||||
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
|
||||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
- [CPM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/cpm/README.md)
|
||||
- [推荐系统](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
|
||||
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
|
||||
- [NAML](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/naml/README.md)
|
||||
- [Wide&Deep[基准]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
|
||||
- [图神经网络](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
|
||||
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)
|
||||
- [GAT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gat/README.md)
|
||||
- [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn//README.md)
|
||||
|
||||
- [研究](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research)
|
||||
- [计算机视觉](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv)
|
||||
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/Readme.md)
|
||||
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
|
||||
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
|
||||
- [CycleGAN](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/research/cv/CycleGAN/README.md)
|
||||
- [人脸属性](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceAttribute/README.md)
|
||||
- [人脸检测](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceDetection/README.md)
|
||||
- [人脸图像质量评估](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceQualityAssessment/README.md)
|
||||
- [人脸识别跟踪](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognitionForTracking/README.md)
|
||||
- [人脸识别](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognition/README.md)
|
||||
- [CenterNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet/README.md)
|
||||
- [自然语言处理](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
|
||||
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
|
||||
- [推荐系统](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend)
|
||||
- [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
|
||||
- [语音](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
|
||||
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
|
||||
- [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
|
||||
- [Wavenet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet/README.md)
|
||||
- [高性能计算](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
|
||||
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
|
||||
- [分子动力学](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)
|
||||
- [PINNs](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/research/hpc/pinns/README.md)
|
||||
- [社区](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)
|
||||
![MindSpore Logo](https://gitee.com/mindspore/mindspore/raw/master/docs/MindSpore-logo.png "MindSpore logo")
|
||||
|
||||
## 公告
|
||||
|
||||
|日期|新闻|
|
||||
| ------------ | ------------------------------------------------------------ |
|
||||
| 2020年9月25日|支持[MindSpore v1.0.0](https://www.mindspore.cn/news/newschildren/en?id=262) |
|
||||
| 2020年9月1日|支持[MindSpore v0.7.0-beta](https://www.mindspore.cn/news/newschildren/en?id=246) |
|
||||
| 2020年7月31日|支持[MindSpore v0.6.0-beta](https://www.mindspore.cn/news/newschildren/en?id=237) |
|
||||
**model_zoo已经被转移到一个独立的新仓库[models](https://gitee.com/mindspore/models)**。
|
||||
|
||||
## 关联站点
|
||||
更多关于model_zoo分离的信息,你可以参看以下issue:
|
||||
|
||||
这里是MindSpore框架提供的可以运行于包括Ascend/GPU/CPU/移动设备等多种设备的模型库。
|
||||
|
||||
相应的专属于Ascend平台的多框架模型可以参考[昇腾ModelZoo](https://hiascend.com/software/modelzoo)以及对应的[代码仓](https://gitee.com/ascend/modelzoo)。
|
||||
|
||||
modelzoo将被转移到一个独立的新仓库[models](https://gitee.com/mindspore/models)。
|
||||
|
||||
## 免责声明
|
||||
|
||||
MindSpore仅提供下载和预处理公共数据集的脚本。我们不拥有这些数据集,也不对它们的质量负责或维护。请确保您具有在数据集许可下使用该数据集的权限。在这些数据集上训练的模型仅用于非商业研究和教学目的。
|
||||
|
||||
致数据集拥有者:如果您不希望将数据集包含在MindSpore中,或者希望以任何方式对其进行更新,我们将根据要求删除或更新所有公共内容。请通过GitHub或Gitee与我们联系。非常感谢您对这个社区的理解和贡献。
|
||||
|
||||
MindSpore已获得Apache 2.0许可,请参见LICENSE文件。
|
||||
|
||||
## 许可证
|
||||
|
||||
[Apache 2.0许可证](https://gitee.com/mindspore/mindspore/blob/master/LICENSE)
|
||||
|
||||
## FAQ
|
||||
|
||||
- **Q: 使用`PYNATIVE_MODE`运行模型出现错误内存不足,例如*Failed to alloc memory pool memory*, 该怎么处理?**
|
||||
|
||||
**A**: `PYNATIVE_MODE`通常比`GRAPH_MODE`使用更多内存,尤其是在需要进行反向传播计算的训练图中,你可以尝试使用一些更小的batch size.
|
||||
|
||||
- **Q: 一些网络运行中报错接口不存在,例如cannot import,该怎么处理?**
|
||||
|
||||
**A**: 优先检查一下获取网络脚本的分支,与所使用的的MindSpore版本是否一致,部分新分支中的模型脚本会使用一些新版本MindSpore才支持的借口,从而在使用老版本MindSpore时会发生报错.
|
||||
- https://gitee.com/mindspore/mindspore/issues/I49LM0
|
||||
- https://gitee.com/mindspore/mindspore/issues/I3TZ87
|
||||
|
|
|
@ -1,135 +0,0 @@
|
|||
# 如何贡献MindSpore ModelZoo
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [如何贡献MindSpore ModelZoo](#如何贡献mindspore-modelzoo)
|
||||
- [准备工作](#准备工作)
|
||||
- [了解贡献协议与流程](#了解贡献协议与流程)
|
||||
- [确定自己贡献的目标](#确定自己贡献的目标)
|
||||
- [代码提交](#代码提交)
|
||||
- [CodeStyle](#codestyle)
|
||||
- [目录结构](#目录结构)
|
||||
- [ReadMe 说明](#readme-说明)
|
||||
- [关于第三方引用](#关于第三方引用)
|
||||
- [引用额外的python库](#引用额外的python库)
|
||||
- [引用第三方开源代码](#引用第三方开源代码)
|
||||
- [引用其他系统库](#引用其他系统库)
|
||||
- [提交自检列表](#提交自检列表)
|
||||
- [维护与交流](#维护与交流)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
本指导用于明确ModelZoo贡献规范,从而确保众多的开发者能够以一种相对统一的风格和流程参与到ModelZoo的建设中。
|
||||
|
||||
## 准备工作
|
||||
|
||||
### 了解贡献协议与流程
|
||||
|
||||
你应该优先参考MindSpore的[CONTRIBUTE.md](../CONTRIBUTING.md)说明来理解MindSpore的开源协议和运作方式,并确保自己已完成CLA的签署。
|
||||
|
||||
<!--
|
||||
### 确定自己贡献的目标
|
||||
|
||||
如果希望进行贡献,我们推荐你先从一些较为容易的issue开始尝试。你可以从以下列表中寻找一些简单的例如bugfix的工作。
|
||||
|
||||
- [wanted bugfix](https://gitee.com/mindspore/mindspore/issues?assignee_id=&author_id=&branch=&issue_search=&label_ids=58021213&label_text=kind/bug&milestone_id=&program_id=&scope=&sort=newest&state=open)
|
||||
|
||||
如果你可以完成独立的网络贡献,你可以从以下列表中找到我们待实现的网络列表。
|
||||
|
||||
- [wanted implement](https://gitee.com/mindspore/mindspore/issues?assignee_id=&author_id=&branch=&issue_search=&label_ids=58022151&label_text=device%2Fascend&milestone_id=&program_id=&scope=&sort=newest&state=open)
|
||||
|
||||
> **Notice** 记得在选定issue之后进行一条回复,从而让别人知道你已经着手于此issue的工作。当你完成某项工作后,也记得回到issue更新你的成果。如果过程中有什么问题,也可以随时在issue中更新你的进展。
|
||||
-->
|
||||
|
||||
## 代码提交
|
||||
|
||||
### CodeStyle
|
||||
|
||||
参考[CONTRIBUTE.md](../CONTRIBUTING.md)中关于CodeStyle的说明,你应该确保自己的代码与MindSpore的现有代码风格保持一致。
|
||||
|
||||
### 目录结构
|
||||
|
||||
为了保证ModelZoo中的实现能够提供一种相对统一的使用方法,我们提供了一种基础的**目录结构模板**,你应该基于此结构来组织自己的工程。
|
||||
|
||||
```shell
|
||||
model_zoo
|
||||
├── official # 官方支持模型
|
||||
│ └── XXX # 模型名
|
||||
│ ├── README.md # 模型说明文档
|
||||
│ ├── requirements.txt # 依赖说明文件
|
||||
│ ├── eval.py # 精度验证脚本
|
||||
│ ├── export.py # 推理模型导出脚本
|
||||
│ ├── scripts # 脚本文件
|
||||
│ │ ├── run_distributed_train.sh # 分布式训练脚本
|
||||
│ │ ├── run_eval.sh # 验证脚本
|
||||
│ │ └── run_standalone_train.sh # 单机训练脚本
|
||||
│ ├── src # 模型定义源码目录
|
||||
│ │ ├── XXXNet.py # 模型结构定义
|
||||
│ │ ├── callback.py # 回调函数定义
|
||||
│ │ ├── config.py # 模型配置参数文件
|
||||
│ │ └── dataset.py # 数据集处理定义
|
||||
│ ├── ascend_infer # (可选)用于在Ascend推理设备上进行离线推理的脚本
|
||||
│ ├── third_party # (可选)第三方代码
|
||||
│ │ └── XXXrepo # (可选)完整克隆自第三方仓库的代码
|
||||
│ └── train.py # 训练脚本
|
||||
├── research # 非官方研究脚本
|
||||
├── community # 合作方脚本链接
|
||||
└── utils # 模型通用工具
|
||||
```
|
||||
|
||||
你可以参照以下原则,根据自己的需要在模板基础上做一些适配自己实现的修改
|
||||
|
||||
1. 模型根目录下只放置带有`main方法`的可执行脚本,模型的定义文件统一放在`src`目录下,该目录下可以根据自己模型的复杂程度自行组织层次结构。
|
||||
2. 配置参数应当与网络定义分离,将所有可配置的参数抽离到`src/config.py`文件中统一定义。
|
||||
3. 上传内容应当只包含脚本、代码和文档,**不要上传**任何数据集或checkpoint之类的数据文件。
|
||||
4. third_party用于存放需要引用的第三方代码,但是不要直接将代码拷贝到目录下上传,而应该使用git链接的形式,在使用时下载。
|
||||
5. 每个模型的代码应当自成闭包,可以独立的迁移使用,不应当依赖模型目录以外的其他代码。utils内只是通用工具,并非通用函数库。
|
||||
6. 上传内容中**不要包含**任何你的个人信息,例如你的主机IP,个人密码,本地目录等。
|
||||
|
||||
### ReadMe 说明
|
||||
|
||||
每个AI模型都需要一个对应的`README.md`作为说明文档,对当前的模型实现进行介绍,从而向其他用户传递以下信息:
|
||||
|
||||
1. 这是个什么模型?来源和参考是什么?
|
||||
2. 当前的实现包含哪些内容?
|
||||
3. 如何使用现有的实现?
|
||||
4. 这个模型表现如何?
|
||||
|
||||
对此,我们提供了一个基础的[README模版](./README_template.md),你应该参考此模版来完善自己的说明文档, 也可以参考其他现有模型的readme。
|
||||
|
||||
### 关于第三方引用
|
||||
|
||||
#### 引用额外的python库
|
||||
|
||||
确保将自己所需要的额外python库和对应版本(如果有明确要求)注明在`requirements.txt`文件。你应该优先选择和MindSpore框架兼容的第三方库。
|
||||
|
||||
#### 引用第三方开源代码
|
||||
|
||||
你应该保证所提交的代码是自己原创开发所完成的。
|
||||
|
||||
当你需要借助一些开源社区的力量,应当优先引用一些成熟可信的开源项目,同时确认自己所选择的开源项目所使用的开源协议是否符合要求。
|
||||
|
||||
当你使用开源代码时,正确的使用方式是通过git地址获取对应代码,并在使用中将对应代码归档在独立的`third_party`目录中,保持与自己的代码隔离。**切勿粗暴的拷贝对应代码片段到自己的提交中。**
|
||||
|
||||
#### 引用其他系统库
|
||||
|
||||
你应该减少对一些独特系统库的依赖,因为这通常意味着你的提交在不同系统中难以复用。
|
||||
|
||||
当你确实需要使用一些独特的系统依赖来完成任务时,你需要在说明中指出对应的获取和安装方法。
|
||||
|
||||
### 提交自检列表
|
||||
|
||||
你所提交的代码应该经过充分的Review, 可以参考以下checklist进行自查
|
||||
|
||||
- [ ] 代码风格符合规范
|
||||
- [ ] 代码在必要的位置添加了注释
|
||||
- [ ] 文档已同步修改
|
||||
- [ ] 同步添加了必要的测试用例
|
||||
- [ ] 所有第三方依赖都已经说明,包括代码引用,python库,数据集,预训练模型等
|
||||
- [ ] 工程组织结构符合[目录结构](#目录结构)中的要求。
|
||||
|
||||
## 维护与交流
|
||||
|
||||
我们十分感谢您对MindSpore社区的贡献,同时十分希望您能够在完成一次提交之后持续关注您所提交的代码。 您可以在所提交模型的README中标注您的署名与常用邮箱等联系方式,并持续关注您的gitee、github信息。
|
||||
|
||||
其他的开发者也许会用到您所提交的模型,使用期间可能会产生一些疑问,此时就可以通过issue、站内信息、邮件等方式与您进行详细的交流.
|
|
@ -1,99 +0,0 @@
|
|||
<TOC>
|
||||
|
||||
# Title, Model name
|
||||
|
||||
> The Description of Model. The paper present this model.
|
||||
|
||||
## Model Architecture
|
||||
|
||||
> There could be various architecture about some model. Represent the architecture of your implementation.
|
||||
|
||||
## Features(optional)
|
||||
|
||||
> Represent the distinctive feature you used in the model implementation. Such as distributed auto-parallel or some special training trick.
|
||||
|
||||
## Dataset
|
||||
|
||||
> Provide the information of the dataset you used. Check the copyrights of the dataset you used, usually don't provide the hyperlink to download the dataset.
|
||||
|
||||
## Requirements
|
||||
|
||||
> Provide details of the software required, including:
|
||||
>
|
||||
> * The additional python package required. Add a `requirements.txt` file to the root dir of model for installing dependencies.
|
||||
> * The necessary third-party code.
|
||||
> * Some other system dependencies.
|
||||
> * Some additional operations before training or prediction.
|
||||
|
||||
## Quick Start
|
||||
|
||||
> How to take a try without understanding anything about the model.
|
||||
|
||||
## Script Description
|
||||
|
||||
> The section provide the detail of implementation.
|
||||
|
||||
### Scripts and Sample Code
|
||||
|
||||
> Explain every file in your project.
|
||||
|
||||
### Script Parameter
|
||||
|
||||
> Explain every parameter of the model. Especially the parameters in `config.py`.
|
||||
|
||||
## Training
|
||||
|
||||
> Provide training information.
|
||||
|
||||
### Training Process
|
||||
|
||||
> Provide the usage of training scripts.
|
||||
|
||||
e.g. Run the following command for distributed training on Ascend.
|
||||
|
||||
```shell
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL]
|
||||
```
|
||||
|
||||
### Transfer Training(Optional)
|
||||
|
||||
> Provide the guidelines about how to run transfer training based on an pretrained model.
|
||||
|
||||
### Training Result
|
||||
|
||||
> Provide the result of training.
|
||||
|
||||
e.g. Training checkpoint will be stored in `XXXX/ckpt_0`. You will get result from log file like the following:
|
||||
|
||||
```
|
||||
epoch: 11 step: 7393 ,rpn_loss: 0.02003, rcnn_loss: 0.52051, rpn_cls_loss: 0.01761, rpn_reg_loss: 0.00241, rcnn_cls_loss: 0.16028, rcnn_reg_loss: 0.08411, rcnn_mask_loss: 0.27588, total_loss: 0.54054
|
||||
epoch: 12 step: 7393 ,rpn_loss: 0.00547, rcnn_loss: 0.39258, rpn_cls_loss: 0.00285, rpn_reg_loss: 0.00262, rcnn_cls_loss: 0.08002, rcnn_reg_loss: 0.04990, rcnn_mask_loss: 0.26245, total_loss: 0.39804
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Evaluation Process
|
||||
|
||||
> Provide the use of evaluation scripts.
|
||||
|
||||
### Evaluation Result
|
||||
|
||||
> Provide the result of evaluation.
|
||||
|
||||
## Performance
|
||||
|
||||
### Training Performance
|
||||
|
||||
> Provide the detail of training performance including finishing loss, throughput, checkpoint size and so on.
|
||||
|
||||
### Inference Performance
|
||||
|
||||
> Provide the detail of evaluation performance including latency, accuracy and so on.
|
||||
|
||||
## Description of Random Situation
|
||||
|
||||
> Explain the random situation in the project.
|
||||
|
||||
## ModeZoo Homepage
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -1,270 +0,0 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [DeepSort描述](#DeepSort描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [训练](#训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [导出mindir模型](#导出mindir模型)
|
||||
- [推理过程](#推理过程)
|
||||
- [用法](#用法)
|
||||
- [结果](#结果)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## DeepSort描述
|
||||
|
||||
DeepSort是2017年提出的多目标跟踪算方法。该网络在MOT16获得冠军,不仅提升了精度,而且速度比之前快20倍。
|
||||
|
||||
[论文](https://arxiv.org/abs/1602.00763): Nicolai Wojke, Alex Bewley, Dietrich Paulus. "SIMPLE ONLINE AND REALTIME TRACKING WITH A DEEP ASSOCIATION METRIC". *Presented at ICIP 2016*.
|
||||
|
||||
## 模型架构
|
||||
|
||||
DeepSort由一个特征提取器、一个卡尔曼滤波和一个匈牙利算法组成。特征提取器用于提取框中人物特征信息,卡尔曼滤波根据上一帧信息预测当前帧人物位置,匈牙利算法用于匹配预测信息与检测到的人物位置信息。
|
||||
|
||||
## 数据集
|
||||
|
||||
使用的数据集:[MOT16](<https://motchallenge.net/data/MOT16.zip>)、[Market-1501](<https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view>)
|
||||
|
||||
MOT16:
|
||||
|
||||
- 数据集大小:1.9G,共14个视频帧序列
|
||||
- test:7个视频序列帧
|
||||
- train:7个序列帧
|
||||
- 数据格式(一个train视频帧序列):
|
||||
- det:视频序列中人物坐标以及置信度等信息
|
||||
- gt:视频跟踪标签信息
|
||||
- img1:视频中所有帧序列
|
||||
- 注意:由于作者提供的视频帧序列检测到的坐标信息和置信度信息不一样,所以在跟踪时使用作者提供的信息,作者提供的[npy](https://drive.google.com/drive/folders/18fKzfqnqhqW3s9zwsCbnVJ5XF2JFeqMp)文件。
|
||||
|
||||
Market-1501:
|
||||
|
||||
- 使用:
|
||||
- 使用目的:训练DeepSort特征提取器
|
||||
- 使用方法: 先使用prepare.py处理数据
|
||||
|
||||
## 环境要求
|
||||
|
||||
- 硬件(Ascend/ModelArts)
|
||||
- 准备Ascend或ModelArts处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
||||
## 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
```python
|
||||
# 进入脚本目录,提取det信息(使用作者提供的检测框信息),在脚本中给出数据路径
|
||||
python process-npy.py
|
||||
# 进入脚本目录,预处理数据集(Market-1501),在脚本中给出数据集路径
|
||||
python prepare.py
|
||||
# 进入脚本目录,训练DeepSort特征提取器
|
||||
python src/deep/train.py --run_modelarts=False --run_distribute=True --data_url="" --train_url=""
|
||||
# 进入脚本目录,提取detections信息
|
||||
python generater_detection.py --run_modelarts=False --run_distribute=True --data_url="" --train_url="" --det_url="" --ckpt_url="" --model_name=""
|
||||
# 进入脚本目录,生成跟踪信息
|
||||
python evaluate_motchallenge.py --data_url="" --train_url="" --detection_url=""
|
||||
|
||||
#Ascend多卡训练
|
||||
bash scripts/run_distribute_train.sh train_code_path RANK_TABLE_FILE DATA_PATH
|
||||
```
|
||||
|
||||
Ascend训练:生成[RANK_TABLE_FILE](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)
|
||||
|
||||
## 脚本说明
|
||||
|
||||
### 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── DeepSort
|
||||
├── scripts
|
||||
│ ├──run_distribute_train.sh // 在Ascend中多卡训练
|
||||
├── src //源码
|
||||
│ │ ├── application_util
|
||||
│ │ │ ├──image_viewer.py
|
||||
│ │ │ ├──preprocessing.py
|
||||
│ │ │ ├──visualization.py
|
||||
│ │ ├──deep
|
||||
│ │ │ ├──feature_extractor.py //提取目标框中人物特征信息
|
||||
│ │ │ ├──original_model.py //特征提取器模型
|
||||
│ │ │ ├──train.py //训练网络模型
|
||||
│ │ ├──sort
|
||||
│ │ │ ├──detection.py
|
||||
│ │ │ ├──iou_matching.py //预测信息与真实框匹配
|
||||
│ │ │ ├──kalman_filter.py //卡尔曼滤波,预测跟踪框信息
|
||||
│ │ │ ├──linear_assignment.py
|
||||
│ │ │ ├──nn_matching.py //框匹配
|
||||
│ │ │ ├──track.py //跟踪器
|
||||
│ │ │ ├──tracker.py //跟踪器
|
||||
├── deep_sort_app.py //目标跟踪
|
||||
├── evaluate_motchallenge.py //生成跟踪结果信息
|
||||
├── generate_videos.py //根据跟踪结果生成跟踪视频
|
||||
├── generater-detection.py //生成detection信息
|
||||
├── postprocess.py //生成Ascend310推理数据
|
||||
├── preprocess.py //处理Ascend310推理结果,生成精度
|
||||
├── prepare.py //处理训练数据集
|
||||
├── process-npy.py //提取帧序列人物坐标和置信度
|
||||
├── show_results.py //展示跟踪结果
|
||||
├── README.md // DeepSort相关说明
|
||||
```
|
||||
|
||||
### 脚本参数
|
||||
|
||||
```python
|
||||
train.py generater_detection.py evaluate_motchallenge.py 中主要参数如下:
|
||||
|
||||
--data_url: 到训练和提取信息数据集的绝对完整路径
|
||||
--train_url: 输出文件路径。
|
||||
--epoch: 总训练轮次
|
||||
--batch_size: 训练批次大小
|
||||
--device_targe: 实现代码的设备。值为'Ascend'
|
||||
--ckpt_url: 训练后保存的检查点文件的绝对完整路径
|
||||
--model_name: 模型文件名称
|
||||
--det_url: 视频帧序列人物信息文件路径
|
||||
--detection_url: 人物坐标信息、置信度以及特征信息文件路径
|
||||
--run_distribute: 多卡运行
|
||||
--run_modelarts: ModelArts上运行
|
||||
```
|
||||
|
||||
### 训练过程
|
||||
|
||||
#### 训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
python src/deep/train.py --run_modelarts=False --run_distribute=False --data_url="" --train_url=""
|
||||
# 或进入脚本目录,执行脚本
|
||||
bash scripts/run_distribute_train.sh train_code_path RANK_TABLE_FILE DATA_PATH
|
||||
```
|
||||
|
||||
经过训练后,损失值如下:
|
||||
|
||||
```bash
|
||||
# grep "loss is " log
|
||||
epoch: 1 step: 3984, loss is 6.4320717
|
||||
epoch: 1 step: 3984, loss is 6.414733
|
||||
epoch: 1 step: 3984, loss is 6.4306755
|
||||
epoch: 1 step: 3984, loss is 6.4387856
|
||||
epoch: 1 step: 3984, loss is 6.463995
|
||||
...
|
||||
epoch: 2 step: 3984, loss is 6.436552
|
||||
epoch: 2 step: 3984, loss is 6.408932
|
||||
epoch: 2 step: 3984, loss is 6.4517527
|
||||
epoch: 2 step: 3984, loss is 6.448922
|
||||
epoch: 2 step: 3984, loss is 6.4611588
|
||||
...
|
||||
```
|
||||
|
||||
模型检查点保存在当前目录下。
|
||||
|
||||
### 评估过程
|
||||
|
||||
#### 评估
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
# 进入脚本目录,提取det信息(使用作者提供的检测框信息)
|
||||
python process-npy.py
|
||||
# 进入脚本目录,提取detections信息
|
||||
python generater_detection.py --run_modelarts False --run_distribute True --data_url "" --train_url "" --det_url "" --ckpt_url "" --model_name ""
|
||||
# 进入脚本目录,生成跟踪信息
|
||||
python evaluate_motchallenge.py --data_url="" --train_url="" --detection_url=""
|
||||
# 生成跟踪结果
|
||||
python eval_motchallenge.py ----run_modelarts=False --data_url="" --train_url="" --result_url=""
|
||||
```
|
||||
|
||||
- [测评工具](https://github.com/cheind/py-motmetrics)
|
||||
|
||||
说明:脚本中引用头文件可能存在一些问题,自行修改头文件路径即可
|
||||
|
||||
```bash
|
||||
#测量精度
|
||||
python motmetrics/apps/eval_motchallenge.py --groundtruths="" --tests=""
|
||||
```
|
||||
|
||||
-
|
||||
测试数据集的准确率如下:
|
||||
|
||||
| 数据 | MOTA | MOTP| MT | ML| IDs | FM | FP | FN |
|
||||
| -------------------------- | -------------------------- | -------------------------- | -------------------------- | -------------------------- | -------------------------- | -------------------------- | -------------------------- | -----------------------------------------------------------
|
||||
| MOT16-02 | 29.0% | 0.207 | 11 | 11| 159 | 226 | 4151 | 8346 |
|
||||
| MOT16-04 | 58.6% | 0.167| 42 | 14| 62 | 242 | 6269 | 13374 |
|
||||
| MOT16-05 | 51.7% | 0.213| 31 | 27| 68 | 109 | 630 | 2595 |
|
||||
| MOT16-09 | 64.3% | 0.162| 12 | 1| 39 | 58 | 309 | 1537 |
|
||||
| MOT16-10 | 49.2% | 0.228| 25 | 1| 201 | 307 | 3089 | 2915 |
|
||||
| MOT16-11 | 65.9% | 0.152| 29 | 9| 54 | 99 | 907 | 2162 |
|
||||
| MOT16-13 | 45.0% | 0.237| 61 | 7| 269 | 335 | 3709 | 2251 |
|
||||
| overall | 51.9% | 0.189| 211 | 70| 852 | 1376 | 19094 | 33190 |
|
||||
|
||||
## [导出mindir模型](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --device_id [DEVICE_ID] --ckpt_file [CKPT_PATH]
|
||||
```
|
||||
|
||||
## [推理过程](#contents)
|
||||
|
||||
### 用法
|
||||
|
||||
执行推断之前,minirir文件必须由export.py导出。输入文件必须为bin格式
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
||||
推理结果文件保存在当前路径中,将文件作为输入,输入到eval_motchallenge.py中,然后输出result文件,输入到测评工具中即可得到精度结果。
|
||||
|
||||
## 模型描述
|
||||
|
||||
### 性能
|
||||
|
||||
#### 评估性能
|
||||
|
||||
| 参数 | ModelArts
|
||||
| -------------------------- | -----------------------------------------------------------
|
||||
| 资源 | Ascend 910;CPU 2.60GHz, 192核;内存:755G
|
||||
| 上传日期 | 2021-08-12
|
||||
| MindSpore版本 | 1.2.0
|
||||
| 数据集 | MOT16 Market-1501
|
||||
| 训练参数 | epoch=100, step=191, batch_size=8, lr=0.1
|
||||
| 优化器 | SGD
|
||||
| 损失函数 | SoftmaxCrossEntropyWithLogits
|
||||
| 损失 | 0.03
|
||||
| 速度 | 9.804毫秒/步
|
||||
| 总时间 | 10分钟
|
||||
| 微调检查点 | 大约40M (.ckpt文件)
|
||||
| 脚本 | [DeepSort脚本]
|
||||
|
||||
## 随机情况说明
|
||||
|
||||
train.py中设置了随机种子。
|
||||
|
||||
## ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -1,14 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
option(MINDSPORE_PATH "mindspore install path" "")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -1,29 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ -d out ]; then
|
||||
rm -rf out
|
||||
fi
|
||||
|
||||
mkdir out
|
||||
cd out || exit
|
||||
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
|
@ -1,32 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
std::string RealPath(std::string_view path);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||
#endif
|
|
@ -1,134 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "/inc/utils.h"
|
||||
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Status;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
|
||||
DEFINE_string(mindir_path, "", "mindir path");
|
||||
DEFINE_string(input0_path, ".", "input0 path");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
if (RealPath(FLAGS_mindir_path).empty()) {
|
||||
std::cout << "Invalid mindir" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(FLAGS_device_id);
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
mindspore::Graph graph;
|
||||
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||
|
||||
Model model;
|
||||
Status ret = model.Build(GraphCell(graph), context);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Build failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<MSTensor> model_inputs = model.GetInputs();
|
||||
if (model_inputs.empty()) {
|
||||
std::cout << "Invalid model, inputs is empty." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input0_files = GetAllFiles(FLAGS_input0_path);
|
||||
|
||||
if (input0_files.empty()) {
|
||||
std::cout << "ERROR: input data empty." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::map<double, double> costTime_map;
|
||||
size_t size = input0_files.size();
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
double startTimeMs;
|
||||
double endTimeMs;
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
std::cout << "Start predict input files:" << input0_files[i] << std::endl;
|
||||
|
||||
auto input0 = ReadFileToTensor(input0_files[i]);
|
||||
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
input0.Data().get(), input0.DataSize());
|
||||
|
||||
for (auto shape : model_inputs[0].Shape()) {
|
||||
std::cout << "model input shape" << shape << std::endl;
|
||||
}
|
||||
gettimeofday(&start, nullptr);
|
||||
ret = model.Predict(inputs, &outputs);
|
||||
gettimeofday(&end, nullptr);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||
WriteResult(input0_files[i], outputs);
|
||||
}
|
||||
double average = 0.0;
|
||||
int inferCount = 0;
|
||||
|
||||
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||
double diff = 0.0;
|
||||
diff = iter->second - iter->first;
|
||||
average += diff;
|
||||
inferCount++;
|
||||
}
|
||||
average = average / inferCount;
|
||||
std::stringstream timeCost;
|
||||
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
|
||||
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
|
||||
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
|
||||
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||
fileStream << timeCost.str();
|
||||
fileStream.close();
|
||||
costTime_map.clear();
|
||||
return 0;
|
||||
}
|
|
@ -1,130 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "inc/utils.h"
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dirName);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> res;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||
continue;
|
||||
}
|
||||
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
}
|
||||
std::sort(res.begin(), res.end());
|
||||
for (auto &f : res) {
|
||||
std::cout << "image file: " << f << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||
std::string homePath = "./result_Files";
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
size_t outputSize;
|
||||
std::shared_ptr<const void> netOutput;
|
||||
netOutput = outputs[i].Data();
|
||||
outputSize = outputs[i].DataSize();
|
||||
std::cout << "output size:" << outputSize << std::endl;
|
||||
int pos = imageFile.rfind('/');
|
||||
std::string fileName(imageFile, pos + 1);
|
||||
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||
std::string outFileName = homePath + "/" + fileName;
|
||||
FILE * outputFile = fopen(outFileName.c_str(), "wb");
|
||||
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << "open failed" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
|
||||
if (dirName.empty()) {
|
||||
std::cout << " dirName is null ! " << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = RealPath(dirName);
|
||||
struct stat s;
|
||||
lstat(realPath.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
DIR *dir;
|
||||
dir = opendir(realPath.c_str());
|
||||
if (dir == nullptr) {
|
||||
std::cout << "Can not open dir " << dirName << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||
return dir;
|
||||
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
|
||||
char realPathMem[PATH_MAX] = {0};
|
||||
char *realPathRet = nullptr;
|
||||
realPathRet = realpath(path.data(), realPathMem);
|
||||
|
||||
if (realPathRet == nullptr) {
|
||||
std::cout << "File: " << path << " is not exist.";
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string realPath(realPathMem);
|
||||
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||
return realPath;
|
||||
}
|
|
@ -1,275 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.application_util import preprocessing
|
||||
from src.application_util import visualization
|
||||
from src.sort import nn_matching
|
||||
from src.sort.detection import Detection
|
||||
from src.sort.tracker import Tracker
|
||||
|
||||
def gather_sequence_info(sequence_dir, detection_file):
|
||||
"""Gather sequence information, such as image filenames, detections,
|
||||
groundtruth (if available).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sequence_dir : str
|
||||
Path to the MOTChallenge sequence directory.
|
||||
detection_file : str
|
||||
Path to the detection file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict
|
||||
A dictionary of the following sequence information:
|
||||
|
||||
* sequence_name: Name of the sequence
|
||||
* image_filenames: A dictionary that maps frame indices to image
|
||||
filenames.
|
||||
* detections: A numpy array of detections in MOTChallenge format.
|
||||
* groundtruth: A numpy array of ground truth in MOTChallenge format.
|
||||
* image_size: Image size (height, width).
|
||||
* min_frame_idx: Index of the first frame.
|
||||
* max_frame_idx: Index of the last frame.
|
||||
|
||||
"""
|
||||
image_dir = os.path.join(sequence_dir, "img1")
|
||||
image_filenames = {
|
||||
int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
|
||||
for f in os.listdir(image_dir)}
|
||||
groundtruth_file = os.path.join(sequence_dir, "gt/gt.txt")
|
||||
|
||||
detections = None
|
||||
if detection_file is not None:
|
||||
detections = np.load(detection_file)
|
||||
groundtruth = None
|
||||
if os.path.exists(groundtruth_file):
|
||||
groundtruth = np.loadtxt(groundtruth_file, delimiter=',')
|
||||
|
||||
if image_filenames:
|
||||
image = cv2.imread(next(iter(image_filenames.values())),
|
||||
cv2.IMREAD_GRAYSCALE)
|
||||
image_size = image.shape
|
||||
else:
|
||||
image_size = None
|
||||
|
||||
if image_filenames:
|
||||
min_frame_idx = min(image_filenames.keys())
|
||||
max_frame_idx = max(image_filenames.keys())
|
||||
else:
|
||||
min_frame_idx = int(detections[:, 0].min())
|
||||
max_frame_idx = int(detections[:, 0].max())
|
||||
|
||||
info_filename = os.path.join(sequence_dir, "seqinfo.ini")
|
||||
if os.path.exists(info_filename):
|
||||
with open(info_filename, "r") as f:
|
||||
line_splits = [l.split('=') for l in f.read().splitlines()[1:]]
|
||||
info_dict = dict(
|
||||
s for s in line_splits if isinstance(s, list) and len(s) == 2)
|
||||
|
||||
update_ms = 1000 / int(info_dict["frameRate"])
|
||||
else:
|
||||
update_ms = None
|
||||
|
||||
feature_dim = detections.shape[1] - 10 if detections is not None else 0
|
||||
seq_info = {
|
||||
"sequence_name": os.path.basename(sequence_dir),
|
||||
"image_filenames": image_filenames,
|
||||
"detections": detections,
|
||||
"groundtruth": groundtruth,
|
||||
"image_size": image_size,
|
||||
"min_frame_idx": min_frame_idx,
|
||||
"max_frame_idx": max_frame_idx,
|
||||
"feature_dim": feature_dim,
|
||||
"update_ms": update_ms
|
||||
}
|
||||
return seq_info
|
||||
|
||||
|
||||
def create_detections(detection_mat, frame_idx, min_height=0):
|
||||
"""Create detections for given frame index from the raw detection matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_mat : ndarray
|
||||
Matrix of detections. The first 10 columns of the detection matrix are
|
||||
in the standard MOTChallenge detection format. In the remaining columns
|
||||
store the feature vector associated with each detection.
|
||||
frame_idx : int
|
||||
The frame index.
|
||||
min_height : Optional[int]
|
||||
A minimum detection bounding box height. Detections that are smaller
|
||||
than this value are disregarded.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[tracker.Detection]
|
||||
Returns detection responses at given frame index.
|
||||
|
||||
"""
|
||||
frame_indices = detection_mat[:, 0].astype(np.int)
|
||||
mask = frame_indices == frame_idx
|
||||
|
||||
detection_list = []
|
||||
for row in detection_mat[mask]:
|
||||
bbox, confidence, feature = row[2:6], row[6], row[10:]
|
||||
|
||||
if bbox[3] < min_height:
|
||||
continue
|
||||
detection_list.append(Detection(bbox, confidence, feature))
|
||||
return detection_list
|
||||
|
||||
|
||||
def run(sequence_dir, detection_file, output_file, min_confidence,
|
||||
nms_max_overlap, min_detection_height, max_cosine_distance,
|
||||
nn_budget, display):
|
||||
"""Run multi-target tracker on a particular sequence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sequence_dir : str
|
||||
Path to the MOTChallenge sequence directory.
|
||||
detection_file : str
|
||||
Path to the detections file.
|
||||
output_file : str
|
||||
Path to the tracking output file. This file will contain the tracking
|
||||
results on completion.
|
||||
min_confidence : float
|
||||
Detection confidence threshold. Disregard all detections that have
|
||||
a confidence lower than this value.
|
||||
nms_max_overlap: float
|
||||
Maximum detection overlap (non-maxima suppression threshold).
|
||||
min_detection_height : int
|
||||
Detection height threshold. Disregard all detections that have
|
||||
a height lower than this value.
|
||||
max_cosine_distance : float
|
||||
Gating threshold for cosine distance metric (object appearance).
|
||||
nn_budget : Optional[int]
|
||||
Maximum size of the appearance descriptor gallery. If None, no budget
|
||||
is enforced.
|
||||
display : bool
|
||||
If True, show visualization of intermediate tracking results.
|
||||
|
||||
"""
|
||||
seq_info = gather_sequence_info(sequence_dir, detection_file)
|
||||
metric = nn_matching.NearestNeighborDistanceMetric(
|
||||
"cosine", max_cosine_distance, nn_budget)
|
||||
tracker = Tracker(metric)
|
||||
results = []
|
||||
|
||||
def frame_callback(vis, frame_idx):
|
||||
print("Processing frame %05d" % frame_idx)
|
||||
|
||||
# Load image and generate detections.
|
||||
detections = create_detections(
|
||||
seq_info["detections"], frame_idx, min_detection_height)
|
||||
detections = [d for d in detections if d.confidence >= min_confidence]
|
||||
|
||||
# Run non-maxima suppression.
|
||||
boxes = np.array([d.tlwh for d in detections])
|
||||
|
||||
scores = np.array([d.confidence for d in detections])
|
||||
indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
|
||||
detections = [detections[i] for i in indices]
|
||||
|
||||
# Update tracker.
|
||||
tracker.predict()
|
||||
tracker.update(detections)
|
||||
|
||||
# Update visualization.
|
||||
if display:
|
||||
image = cv2.imread(
|
||||
seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)
|
||||
vis.set_image(image.copy())
|
||||
vis.draw_detections(detections)
|
||||
vis.draw_trackers(tracker.tracks)
|
||||
|
||||
# Store results.
|
||||
for track in tracker.tracks:
|
||||
if not track.is_confirmed() or track.time_since_update > 1:
|
||||
continue
|
||||
bbox = track.to_tlwh()
|
||||
results.append([
|
||||
frame_idx, track.track_id, bbox[0], bbox[1], bbox[2], bbox[3]])
|
||||
|
||||
# Run tracker.
|
||||
if display:
|
||||
visualizer = visualization.Visualization(seq_info, update_ms=5)
|
||||
else:
|
||||
visualizer = visualization.NoVisualization(seq_info)
|
||||
visualizer.run(frame_callback)
|
||||
|
||||
# Store results.
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
for row in results:
|
||||
print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
|
||||
row[0], row[1], row[2], row[3], row[4], row[5]), file=f)
|
||||
|
||||
|
||||
def bool_string(input_string):
|
||||
if input_string not in {"True", "False"}:
|
||||
raise ValueError("Please Enter a valid Ture/False choice")
|
||||
return input_string == "True"
|
||||
|
||||
def parse_args():
|
||||
""" Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Deep SORT")
|
||||
parser.add_argument(
|
||||
"--sequence_dir", help="Path to MOTChallenge sequence directory",
|
||||
default="../MOT16/train/MOT16-02")
|
||||
parser.add_argument(
|
||||
"--detection_file", help="Path to custom detections.", default="./detections/MOT16_POI_train/MOT16-02.npy")
|
||||
parser.add_argument(
|
||||
"--output_file", help="Path to the tracking output file. This file will"
|
||||
" contain the tracking results on completion.",
|
||||
default="./tmp/hypotheses-det.txt")
|
||||
parser.add_argument(
|
||||
"--min_confidence", help="Detection confidence threshold. Disregard "
|
||||
"all detections that have a confidence lower than this value.",
|
||||
default=0.8, type=float)
|
||||
parser.add_argument(
|
||||
"--min_detection_height", help="Threshold on the detection bounding "
|
||||
"box height. Detections with height smaller than this value are "
|
||||
"disregarded", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--nms_max_overlap", help="Non-maxima suppression threshold: Maximum "
|
||||
"detection overlap.", default=1.0, type=float)
|
||||
parser.add_argument(
|
||||
"--max_cosine_distance", help="Gating threshold for cosine distance "
|
||||
"metric (object appearance).", type=float, default=0.2)
|
||||
parser.add_argument(
|
||||
"--nn_budget", help="Maximum size of the appearance descriptors "
|
||||
"gallery. If None, no budget is enforced.", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--display", help="Show intermediate tracking results",
|
||||
default=False, type=bool_string)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
run(
|
||||
args.sequence_dir, args.detection_file, args.output_file,
|
||||
args.min_confidence, args.nms_max_overlap, args.min_detection_height,
|
||||
args.max_cosine_distance, args.nn_budget, args.display)
|
|
@ -1,66 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import deep_sort_app
|
||||
|
||||
def parse_args():
|
||||
""" Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="MOTChallenge evaluation")
|
||||
parser.add_argument(
|
||||
"--detection_url", type=str, help="Path to detection files.")
|
||||
parser.add_argument(
|
||||
"--data_url", type=str, help="Path to image data.")
|
||||
parser.add_argument(
|
||||
"--train_url", type=str, help="Path to save result.")
|
||||
parser.add_argument(
|
||||
"--min_confidence", help="Detection confidence threshold. Disregard "
|
||||
"all detections that have a confidence lower than this value.",
|
||||
default=0.0, type=float)
|
||||
parser.add_argument(
|
||||
"--min_detection_height", help="Threshold on the detection bounding "
|
||||
"box height. Detections with height smaller than this value are "
|
||||
"disregarded", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--nms_max_overlap", help="Non-maxima suppression threshold: Maximum "
|
||||
"detection overlap.", default=1.0, type=float)
|
||||
parser.add_argument(
|
||||
"--max_cosine_distance", help="Gating threshold for cosine distance "
|
||||
"metric (object appearance).", type=float, default=0.2)
|
||||
parser.add_argument(
|
||||
"--nn_budget", help="Maximum size of the appearance descriptors "
|
||||
"gallery. If None, no budget is enforced.", type=int, default=100)
|
||||
return parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
detection_dir = args.detection_url
|
||||
DATA_DIR = args.data_url + '/'
|
||||
local_result_url = args.train_url
|
||||
|
||||
if not os.path.exists(local_result_url):
|
||||
os.makedirs(local_result_url)
|
||||
sequences = os.listdir(DATA_DIR)
|
||||
for sequence in sequences:
|
||||
print("Running sequence %s" % sequence)
|
||||
sequence_dir = os.path.join(DATA_DIR, sequence)
|
||||
detection_file = os.path.join(detection_dir, "%s.npy" % sequence)
|
||||
output_file = os.path.join(local_result_url, "%s.txt" % sequence)
|
||||
deep_sort_app.run(
|
||||
sequence_dir, detection_file, output_file, args.min_confidence,
|
||||
args.nms_max_overlap, args.min_detection_height,
|
||||
args.max_cosine_distance, args.nn_budget, display=False)
|
|
@ -1,50 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.deep.original_model import Net
|
||||
|
||||
parser = argparse.ArgumentParser(description='Tracking')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--image_height", type=int, default=128, help="Image height.")
|
||||
parser.add_argument("--image_width", type=int, default=64, help="Image width.")
|
||||
parser.add_argument("--file_name", type=str, default="deepsort", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
net = Net(reid=True, ascend=True)
|
||||
|
||||
param_dict = load_checkpoint(args_opt.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([args_opt.batch_size, 3, args_opt.image_height, args_opt.image_width]), ms.float32)
|
||||
export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
|
@ -1,77 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import show_results
|
||||
|
||||
def convert(filename_input, filename_output, ffmpeg_executable="ffmpeg"):
|
||||
import subprocess
|
||||
command = [ffmpeg_executable, "-i", filename_input, "-c:v", "libx264",
|
||||
"-preset", "slow", "-crf", "21", filename_output]
|
||||
subprocess.call(command)
|
||||
|
||||
|
||||
def parse_args():
|
||||
""" Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Siamese Tracking")
|
||||
parser.add_argument('--data_url', type=str, default='', help='Det directory.')
|
||||
parser.add_argument('--train_url', type=str, help='Folder to store the videos in')
|
||||
parser.add_argument(
|
||||
"--result_dir", help="Path to the folder with tracking output.", default="")
|
||||
parser.add_argument(
|
||||
"--convert_h264", help="If true, convert videos to libx264 (requires "
|
||||
"FFMPEG", default=False)
|
||||
parser.add_argument(
|
||||
"--update_ms", help="Time between consecutive frames in milliseconds. "
|
||||
"Defaults to the frame_rate specified in seqinfo.ini, if available.",
|
||||
default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
data_dir = args.data_url
|
||||
local_train_url = args.train_url
|
||||
result_dir = args.result_dir
|
||||
|
||||
|
||||
os.makedirs(local_train_url, exist_ok=True)
|
||||
for sequence_txt in os.listdir(result_dir):
|
||||
sequence = os.path.splitext(sequence_txt)[0]
|
||||
sequence_dir = os.path.join(data_dir, sequence)
|
||||
if not os.path.exists(sequence_dir):
|
||||
continue
|
||||
result_file = os.path.join(result_dir, sequence_txt)
|
||||
update_ms = args.update_ms
|
||||
video_filename = os.path.join(local_train_url, "%s.avi" % sequence)
|
||||
|
||||
print("Saving %s to %s." % (sequence_txt, video_filename))
|
||||
show_results.run(
|
||||
sequence_dir, result_file, False, None, update_ms, video_filename)
|
||||
|
||||
if not args.convert_h264:
|
||||
import sys
|
||||
sys.exit()
|
||||
for sequence_txt in os.listdir(result_dir):
|
||||
sequence = os.path.splitext(sequence_txt)[0]
|
||||
sequence_dir = os.path.join(data_dir, sequence)
|
||||
if not os.path.exists(sequence_dir):
|
||||
continue
|
||||
filename_in = os.path.join(local_train_url, "%s.avi" % sequence)
|
||||
filename_out = os.path.join(local_train_url, "%s.mp4" % sequence)
|
||||
convert(filename_in, filename_out)
|
|
@ -1,259 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import errno
|
||||
import argparse
|
||||
import ast
|
||||
import matplotlib
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init
|
||||
from mindspore import context
|
||||
from src.deep.feature_extractor import Extractor
|
||||
|
||||
matplotlib.use("Agg")
|
||||
ASCEND_SLOG_PRINT_TO_STDOUT = 1
|
||||
|
||||
|
||||
def extract_image_patch(image, bbox, patch_shape=None):
|
||||
"""Extract image patch from bounding box.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : ndarray
|
||||
The full image.
|
||||
bbox : array_like
|
||||
The bounding box in format (x, y, width, height).
|
||||
patch_shape : Optional[array_like]
|
||||
This parameter can be used to enforce a desired patch shape
|
||||
(height, width). First, the `bbox` is adapted to the aspect ratio
|
||||
of the patch shape, then it is clipped at the image boundaries.
|
||||
If None, the shape is computed from :arg:`bbox`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray | NoneType
|
||||
An image patch showing the :arg:`bbox`, optionally reshaped to
|
||||
:arg:`patch_shape`.
|
||||
Returns None if the bounding box is empty or fully outside of the image
|
||||
boundaries.
|
||||
|
||||
"""
|
||||
bbox = np.array(bbox)
|
||||
if patch_shape is not None:
|
||||
# correct aspect ratio to patch shape
|
||||
target_aspect = float(patch_shape[1]) / patch_shape[0]
|
||||
new_width = target_aspect * bbox[3]
|
||||
bbox[0] -= (new_width - bbox[2]) / 2
|
||||
bbox[2] = new_width
|
||||
|
||||
# convert to top left, bottom right
|
||||
bbox[2:] += bbox[:2]
|
||||
bbox = bbox.astype(np.int)
|
||||
|
||||
# clip at image boundaries
|
||||
bbox[:2] = np.maximum(0, bbox[:2])
|
||||
bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
|
||||
if np.any(bbox[:2] >= bbox[2:]):
|
||||
return None
|
||||
sx, sy, ex, ey = bbox
|
||||
image = image[sy:ey, sx:ex]
|
||||
return image
|
||||
|
||||
|
||||
class ImageEncoder:
|
||||
|
||||
def __init__(self, model_path, batch_size=32):
|
||||
|
||||
self.extractor = Extractor(model_path, batch_size)
|
||||
|
||||
def _get_features(self, bbox_xywh, ori_img):
|
||||
im_crops = []
|
||||
self.height, self.width = ori_img.shape[:2]
|
||||
for box in bbox_xywh:
|
||||
im = extract_image_patch(ori_img, box)
|
||||
if im is None:
|
||||
print("WARNING: Failed to extract image patch: %s." % str(box))
|
||||
im = np.random.uniform(
|
||||
0., 255., ori_img.shape).astype(np.uint8)
|
||||
im_crops.append(im)
|
||||
if im_crops:
|
||||
features = self.extractor(im_crops)
|
||||
else:
|
||||
features = np.array([])
|
||||
return features
|
||||
|
||||
|
||||
def __call__(self, image, boxes, batch_size=32):
|
||||
features = self._get_features(boxes, image)
|
||||
return features
|
||||
|
||||
|
||||
def create_box_encoder(model_filename, batch_size=32):
|
||||
image_encoder = ImageEncoder(model_filename, batch_size)
|
||||
|
||||
def encoder_box(image, boxes):
|
||||
return image_encoder(image, boxes)
|
||||
|
||||
return encoder_box
|
||||
|
||||
|
||||
def generate_detections(encoder_boxes, mot_dir, output_dir, det_path=None, detection_dir=None):
|
||||
"""Generate detections with features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
encoder : Callable[image, ndarray] -> ndarray
|
||||
The encoder function takes as input a BGR color image and a matrix of
|
||||
bounding boxes in format `(x, y, w, h)` and returns a matrix of
|
||||
corresponding feature vectors.
|
||||
mot_dir : str
|
||||
Path to the MOTChallenge directory (can be either train or test).
|
||||
output_dir
|
||||
Path to the output directory. Will be created if it does not exist.
|
||||
detection_dir
|
||||
Path to custom detections. The directory structure should be the default
|
||||
MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the
|
||||
standard MOTChallenge detections.
|
||||
|
||||
"""
|
||||
if detection_dir is None:
|
||||
detection_dir = mot_dir
|
||||
try:
|
||||
os.makedirs(output_dir)
|
||||
except OSError as exception:
|
||||
if exception.errno == errno.EEXIST and os.path.isdir(output_dir):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Failed to created output directory '%s'" % output_dir)
|
||||
for sequence in os.listdir(mot_dir):
|
||||
print("Processing %s" % sequence)
|
||||
sequence_dir = os.path.join(mot_dir, sequence)
|
||||
|
||||
image_dir = os.path.join(sequence_dir, "img1")
|
||||
#image_dir = os.path.join(mot_dir, "img1")
|
||||
image_filenames = {
|
||||
int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
|
||||
for f in os.listdir(image_dir)}
|
||||
if det_path:
|
||||
detection_dir = os.path.join(det_path, sequence)
|
||||
else:
|
||||
detection_dir = os.path.join(sequence_dir, sequence)
|
||||
detection_file = os.path.join(detection_dir, "det/det.txt")
|
||||
|
||||
detections_in = np.loadtxt(detection_file, delimiter=',')
|
||||
detections_out = []
|
||||
|
||||
frame_indices = detections_in[:, 0].astype(np.int)
|
||||
min_frame_idx = frame_indices.astype(np.int).min()
|
||||
max_frame_idx = frame_indices.astype(np.int).max()
|
||||
for frame_idx in range(min_frame_idx, max_frame_idx + 1):
|
||||
print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
|
||||
mask = frame_indices == frame_idx
|
||||
rows = detections_in[mask]
|
||||
|
||||
if frame_idx not in image_filenames:
|
||||
print("WARNING could not find image for frame %d" % frame_idx)
|
||||
continue
|
||||
bgr_image = cv2.imread(
|
||||
image_filenames[frame_idx], cv2.IMREAD_COLOR)
|
||||
features = encoder_boxes(bgr_image, rows[:, 2:6].copy())
|
||||
detections_out += [np.r_[(row, feature)]
|
||||
for row, feature in zip(rows, features)]
|
||||
|
||||
output_filename = os.path.join(output_dir, "%s.npy" % sequence)
|
||||
print(output_filename)
|
||||
np.save(
|
||||
output_filename, np.asarray(detections_out), allow_pickle=False)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Re-ID feature extractor")
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval,
|
||||
default=False, help='Run distribute')
|
||||
parser.add_argument('--run_modelarts', type=ast.literal_eval,
|
||||
default=False, help='Run distribute')
|
||||
parser.add_argument("--device_id", type=int, default=4,
|
||||
help="Use which device.")
|
||||
parser.add_argument('--data_url', type=str,
|
||||
default='', help='Det directory.')
|
||||
parser.add_argument('--train_url', type=str, default='',
|
||||
help='Train output directory.')
|
||||
parser.add_argument('--det_url', type=str, default='',
|
||||
help='Train output directory.')
|
||||
parser.add_argument('--batch_size', type=int,
|
||||
default=32, help='Batach size.')
|
||||
parser.add_argument("--ckpt_url", type=str, default='',
|
||||
help="Path to checkpoint.")
|
||||
parser.add_argument("--model_name", type=str,
|
||||
default="deepsort-30000_24.ckpt", help="Name of checkpoint.")
|
||||
parser.add_argument(
|
||||
"--detection_dir", help="Path to custom detections. Defaults to", default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
args = parse_args()
|
||||
if args.run_modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = '/cache/data'
|
||||
local_ckpt_url = '/cache/ckpt'
|
||||
local_train_url = '/cache/train'
|
||||
local_det_url = '/cache/det'
|
||||
mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
|
||||
mox.file.copy_parallel(args.data_url, local_data_url)
|
||||
mox.file.copy_parallel(args.det_url, local_det_url)
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
DATA_DIR = local_data_url + '/'
|
||||
ckpt_dir = local_ckpt_url + '/'
|
||||
det_dir = local_det_url + '/'
|
||||
else:
|
||||
if args.run_distribute:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
else:
|
||||
context.set_context(device_id=args.device_id)
|
||||
device_num = 1
|
||||
device_id = args.device_id
|
||||
DATA_DIR = args.data_url
|
||||
local_train_url = args.train_url
|
||||
ckpt_dir = args.ckpt_url
|
||||
det_dir = args.det_url
|
||||
|
||||
encoder = create_box_encoder(
|
||||
ckpt_dir+args.model_name, batch_size=args.batch_size)
|
||||
generate_detections(encoder, DATA_DIR, local_train_url, det_path=det_dir)
|
||||
if args.run_modelarts:
|
||||
mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
|
|
@ -1,66 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore deepsort infer')
|
||||
# Path for data
|
||||
parser.add_argument('--det_dir', type=str, default='', help='det directory.')
|
||||
parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
|
||||
parser.add_argument('--output_dir', type=str, default="./", help='output dir.')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
rst_path = args.result_dir
|
||||
start = end = 0
|
||||
|
||||
for sequence in os.listdir(args.det_dir):
|
||||
#sequence_dir = os.path.join(mot_dir, sequence)
|
||||
start = end
|
||||
detection_dir = os.path.join(args.det_dir, sequence)
|
||||
detection_file = os.path.join(detection_dir, "det/det.txt")
|
||||
|
||||
detections_in = np.loadtxt(detection_file, delimiter=',')
|
||||
detections_out = []
|
||||
raws = []
|
||||
features = []
|
||||
|
||||
frame_indices = detections_in[:, 0].astype(np.int)
|
||||
min_frame_idx = frame_indices.astype(np.int).min()
|
||||
max_frame_idx = frame_indices.astype(np.int).max()
|
||||
|
||||
for frame_idx in range(min_frame_idx, max_frame_idx + 1):
|
||||
mask = frame_indices == frame_idx
|
||||
rows = detections_in[mask]
|
||||
|
||||
for box in rows:
|
||||
raws.append(box)
|
||||
end += 1
|
||||
|
||||
raws = np.array(raws)
|
||||
for i in range(start, end):
|
||||
file_name = os.path.join(rst_path, "DeepSort_data_bs" + str(1) + '_' + str(i) + '_0.bin')
|
||||
output = np.fromfile(file_name, np.float32)
|
||||
features.append(output)
|
||||
features = np.array(features)
|
||||
detections_out += [np.r_[(row, feature)] for row, feature in zip(raws, features)]
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
output_filename = os.path.join(args.output_dir, "%s.npy" % sequence)
|
||||
print(output_filename)
|
||||
np.save(output_filename, np.asarray(detections_out), allow_pickle=False)
|
|
@ -1,122 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
from shutil import copyfile
|
||||
|
||||
# You only need to change this line to your dataset download path
|
||||
download_path = '../data/Market-1501'
|
||||
|
||||
if not os.path.isdir(download_path):
|
||||
print('please change the download_path')
|
||||
|
||||
save_path = download_path + '/pytorch'
|
||||
if not os.path.isdir(save_path):
|
||||
os.mkdir(save_path)
|
||||
#-----------------------------------------
|
||||
#query
|
||||
query_path = download_path + '/query'
|
||||
query_save_path = download_path + '/pytorch/query'
|
||||
if not os.path.isdir(query_save_path):
|
||||
os.mkdir(query_save_path)
|
||||
|
||||
for root, dirs, files in os.walk(query_path, topdown=True):
|
||||
for name in files:
|
||||
if not name[-3:] == 'jpg':
|
||||
continue
|
||||
ID = name.split('_')
|
||||
src_path = query_path + '/' + name
|
||||
dst_path = query_save_path + '/' + ID[0]
|
||||
if not os.path.isdir(dst_path):
|
||||
os.mkdir(dst_path)
|
||||
copyfile(src_path, dst_path + '/' + name)
|
||||
|
||||
#-----------------------------------------
|
||||
#multi-query
|
||||
query_path = download_path + '/gt_bbox'
|
||||
# for dukemtmc-reid, we do not need multi-query
|
||||
if os.path.isdir(query_path):
|
||||
query_save_path = download_path + '/pytorch/multi-query'
|
||||
if not os.path.isdir(query_save_path):
|
||||
os.mkdir(query_save_path)
|
||||
|
||||
for root, dirs, files in os.walk(query_path, topdown=True):
|
||||
for name in files:
|
||||
if not name[-3:] == 'jpg':
|
||||
continue
|
||||
ID = name.split('_')
|
||||
src_path = query_path + '/' + name
|
||||
dst_path = query_save_path + '/' + ID[0]
|
||||
if not os.path.isdir(dst_path):
|
||||
os.mkdir(dst_path)
|
||||
copyfile(src_path, dst_path + '/' + name)
|
||||
|
||||
#-----------------------------------------
|
||||
#gallery
|
||||
gallery_path = download_path + '/bounding_box_test'
|
||||
gallery_save_path = download_path + '/pytorch/gallery'
|
||||
if not os.path.isdir(gallery_save_path):
|
||||
os.mkdir(gallery_save_path)
|
||||
|
||||
for root, dirs, files in os.walk(gallery_path, topdown=True):
|
||||
for name in files:
|
||||
if not name[-3:] == 'jpg':
|
||||
continue
|
||||
ID = name.split('_')
|
||||
src_path = gallery_path + '/' + name
|
||||
dst_path = gallery_save_path + '/' + ID[0]
|
||||
if not os.path.isdir(dst_path):
|
||||
os.mkdir(dst_path)
|
||||
copyfile(src_path, dst_path + '/' + name)
|
||||
|
||||
#---------------------------------------
|
||||
#train_all
|
||||
train_path = download_path + '/bounding_box_train'
|
||||
train_save_path = download_path + '/pytorch/train_all'
|
||||
if not os.path.isdir(train_save_path):
|
||||
os.mkdir(train_save_path)
|
||||
|
||||
for root, dirs, files in os.walk(train_path, topdown=True):
|
||||
for name in files:
|
||||
if not name[-3:] == 'jpg':
|
||||
continue
|
||||
ID = name.split('_')
|
||||
src_path = train_path + '/' + name
|
||||
dst_path = train_save_path + '/' + ID[0]
|
||||
if not os.path.isdir(dst_path):
|
||||
os.mkdir(dst_path)
|
||||
copyfile(src_path, dst_path + '/' + name)
|
||||
|
||||
|
||||
#---------------------------------------
|
||||
#train_val
|
||||
train_path = download_path + '/bounding_box_train'
|
||||
train_save_path = download_path + '/pytorch/train'
|
||||
val_save_path = download_path + '/pytorch/val'
|
||||
if not os.path.isdir(train_save_path):
|
||||
os.mkdir(train_save_path)
|
||||
os.mkdir(val_save_path)
|
||||
|
||||
for root, dirs, files in os.walk(train_path, topdown=True):
|
||||
for name in files:
|
||||
if not name[-3:] == 'jpg':
|
||||
continue
|
||||
ID = name.split('_')
|
||||
src_path = train_path + '/' + name
|
||||
dst_path = train_save_path + '/' + ID[0]
|
||||
if not os.path.isdir(dst_path):
|
||||
os.mkdir(dst_path)
|
||||
dst_path = val_save_path + '/' + ID[0] #first image is used as val image
|
||||
os.mkdir(dst_path)
|
||||
copyfile(src_path, dst_path + '/' + name)
|
|
@ -1,192 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import argparse
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
matplotlib.use("Agg")
|
||||
ASCEND_SLOG_PRINT_TO_STDOUT = 1
|
||||
|
||||
def extract_image_patch(image, bbox, patch_shape=None):
|
||||
"""Extract image patch from bounding box.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : ndarray
|
||||
The full image.
|
||||
bbox : array_like
|
||||
The bounding box in format (x, y, width, height).
|
||||
patch_shape : Optional[array_like]
|
||||
This parameter can be used to enforce a desired patch shape
|
||||
(height, width). First, the `bbox` is adapted to the aspect ratio
|
||||
of the patch shape, then it is clipped at the image boundaries.
|
||||
If None, the shape is computed from :arg:`bbox`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray | NoneType
|
||||
An image patch showing the :arg:`bbox`, optionally reshaped to
|
||||
:arg:`patch_shape`.
|
||||
Returns None if the bounding box is empty or fully outside of the image
|
||||
boundaries.
|
||||
|
||||
"""
|
||||
bbox = np.array(bbox)
|
||||
if patch_shape is not None:
|
||||
# correct aspect ratio to patch shape
|
||||
target_aspect = float(patch_shape[1]) / patch_shape[0]
|
||||
new_width = target_aspect * bbox[3]
|
||||
bbox[0] -= (new_width - bbox[2]) / 2
|
||||
bbox[2] = new_width
|
||||
|
||||
# convert to top left, bottom right
|
||||
bbox[2:] += bbox[:2]
|
||||
bbox = bbox.astype(np.int)
|
||||
|
||||
# clip at image boundaries
|
||||
bbox[:2] = np.maximum(0, bbox[:2])
|
||||
bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
|
||||
if np.any(bbox[:2] >= bbox[2:]):
|
||||
return None
|
||||
sx, sy, ex, ey = bbox
|
||||
|
||||
image = image[sy:ey, sx:ex]
|
||||
return image
|
||||
|
||||
def statistic_normalize_img(img, statistic_norm=True):
|
||||
"""Statistic normalize images."""
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
if statistic_norm:
|
||||
img = (img - mean) / std
|
||||
img = img.astype(np.float32)
|
||||
return img
|
||||
|
||||
def preprocess(im_crops):
|
||||
"""
|
||||
TODO:
|
||||
1. to float with scale from 0 to 1
|
||||
2. resize to (64, 128) as Market1501 dataset did
|
||||
3. concatenate to a numpy array
|
||||
3. to torch Tensor
|
||||
4. normalize
|
||||
"""
|
||||
def _resize(im, size):
|
||||
return cv2.resize(im.astype(np.float32)/255., size)
|
||||
im_batch = []
|
||||
size = (64, 128)
|
||||
for im in im_crops:
|
||||
im = _resize(im, size)
|
||||
im = statistic_normalize_img(im)
|
||||
im = im.transpose(2, 0, 1).copy()
|
||||
im = np.expand_dims(im, 0)
|
||||
im_batch.append(im)
|
||||
|
||||
im_batch = np.array(im_batch)
|
||||
return im_batch
|
||||
|
||||
|
||||
def get_features(bbox_xywh, ori_img):
|
||||
im_crops = []
|
||||
for box in bbox_xywh:
|
||||
im = extract_image_patch(ori_img, box)
|
||||
if im is None:
|
||||
print("WARNING: Failed to extract image patch: %s." % str(box))
|
||||
im = np.random.uniform(
|
||||
0., 255., ori_img.shape).astype(np.uint8)
|
||||
im_crops.append(im)
|
||||
if im_crops:
|
||||
features = preprocess(im_crops)
|
||||
else:
|
||||
features = np.array([])
|
||||
return features
|
||||
|
||||
def generate_detections(mot_dir, img_path, det_path=None):
|
||||
"""Generate detections with features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
encoder : Callable[image, ndarray] -> ndarray
|
||||
The encoder function takes as input a BGR color image and a matrix of
|
||||
bounding boxes in format `(x, y, w, h)` and returns a matrix of
|
||||
corresponding feature vectors.
|
||||
mot_dir : str
|
||||
Path to the MOTChallenge directory (can be either train or test).
|
||||
output_dir
|
||||
Path to the output directory. Will be created if it does not exist.
|
||||
detection_dir
|
||||
Path to custom detections. The directory structure should be the default
|
||||
MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the
|
||||
standard MOTChallenge detections.
|
||||
|
||||
"""
|
||||
|
||||
count = 0
|
||||
for sequence in os.listdir(mot_dir):
|
||||
print("Processing %s" % sequence)
|
||||
sequence_dir = os.path.join(mot_dir, sequence)
|
||||
|
||||
image_dir = os.path.join(sequence_dir, "img1")
|
||||
#image_dir = os.path.join(mot_dir, "img1")
|
||||
image_filenames = {
|
||||
int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
|
||||
for f in os.listdir(image_dir)}
|
||||
|
||||
if det_path is not None:
|
||||
detection_dir = os.path.join(det_path, sequence)
|
||||
else:
|
||||
detection_dir = os.path.join(sequence_dir, sequence)
|
||||
detection_file = os.path.join(detection_dir, "det/det.txt")
|
||||
detections_in = np.loadtxt(detection_file, delimiter=',')
|
||||
|
||||
frame_indices = detections_in[:, 0].astype(np.int)
|
||||
min_frame_idx = frame_indices.astype(np.int).min()
|
||||
max_frame_idx = frame_indices.astype(np.int).max()
|
||||
for frame_idx in range(min_frame_idx, max_frame_idx + 1):
|
||||
print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
|
||||
mask = frame_indices == frame_idx
|
||||
rows = detections_in[mask]
|
||||
|
||||
if frame_idx not in image_filenames:
|
||||
print("WARNING could not find image for frame %d" % frame_idx)
|
||||
continue
|
||||
bgr_image = cv2.imread(image_filenames[frame_idx], cv2.IMREAD_COLOR)
|
||||
features = get_features(rows[:, 2:6].copy(), bgr_image)
|
||||
|
||||
for data in features:
|
||||
file_name = "DeepSort_data_bs" + str(1) + "_" + str(count) + ".bin"
|
||||
file_path = img_path + "/" + file_name
|
||||
data.tofile(file_path)
|
||||
count += 1
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Ascend 310 feature extractor")
|
||||
parser.add_argument('--data_path', type=str, default='', help='MOT directory.')
|
||||
parser.add_argument('--det_path', type=str, default='', help='Det directory.')
|
||||
parser.add_argument('--result_path', type=str, default='', help='Inference output directory.')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
image_path = os.path.join(args.result_path, "00_data")
|
||||
os.mkdir(image_path)
|
||||
generate_detections(args.data_path, image_path, args.det_path)
|
|
@ -1,42 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
npy_dir = "" #the npy files provided by author, and the directory name is MOT16_POI_train.
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(npy_dir):
|
||||
for filename in filenames:
|
||||
load_dir = os.path.join(dirpath, filename)
|
||||
loadData = np.load(load_dir)
|
||||
dirname = "./det/" + filename[ : 8] + "/" + "det/"
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
f = open(dirname+"det.txt", 'a')
|
||||
for info in loadData:
|
||||
s = ""
|
||||
for i, num in enumerate(info):
|
||||
if i in (0, 1, 7, 8, 9):
|
||||
s += str(int(num))
|
||||
if i != 9:
|
||||
s += ','
|
||||
elif i < 10:
|
||||
s += str(num)
|
||||
s += ','
|
||||
else:
|
||||
break
|
||||
#print(s)
|
||||
f.write(s)
|
||||
f.write('\n')
|
|
@ -1,5 +0,0 @@
|
|||
cv2
|
||||
mindspore
|
||||
numpy
|
||||
matplotlib
|
||||
shutil
|
|
@ -1,86 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [train_code_path][RANK_TABLE_FILE][DATA_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
train_code_path=$(get_real_path $1)
|
||||
echo $train_code_path
|
||||
|
||||
if [ ! -d $train_code_path ]
|
||||
then
|
||||
echo "error: train_code_path=$train_code_path is not a dictionary."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RANK_TABLE_FILE=$(get_real_path $2)
|
||||
echo $RANK_TABLE_FILE
|
||||
|
||||
if [ ! -f $RANK_TABLE_FILE ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATA_PATH=$(get_real_path $3)
|
||||
echo $DATA_PATH
|
||||
|
||||
if [ ! -d $DATA_PATH ]
|
||||
then
|
||||
echo "error: DATA_PATH=$DATA_PATH is not a dictionary."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -c unlimited
|
||||
export SLOG_PRINT_TO_STDOUT=0
|
||||
export RANK_TABLE_FILE=$RANK_TABLE_FILE
|
||||
export RANK_SIZE=8
|
||||
export RANK_START_ID=0
|
||||
|
||||
|
||||
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||
do
|
||||
export RANK_ID=${i}
|
||||
export DEVICE_ID=$((i + RANK_START_ID))
|
||||
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
||||
if [ -d ${train_code_path}/device${DEVICE_ID} ]; then
|
||||
rm -rf ${train_code_path}/device${DEVICE_ID}
|
||||
fi
|
||||
mkdir ${train_code_path}/device${DEVICE_ID}
|
||||
cd ${train_code_path}/device${DEVICE_ID} || exit
|
||||
python ${train_code_path}/deep_sort/deep/train.py --data_url=${DATA_PATH} \
|
||||
--train_url=./checkpoint \
|
||||
--run_distribute=True \
|
||||
--run_modelarts=False > out.log 2>&1 &
|
||||
done
|
|
@ -1,126 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 4 || $# -gt 5 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
DEVICE_TARGET must choose from ['GPU', 'CPU', 'Ascend']
|
||||
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'.
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
model=$(get_real_path $1)
|
||||
dataset_path=$(get_real_path $2)
|
||||
det_path=$(get_real_path $3)
|
||||
|
||||
if [ "$4" == "y" ] || [ "$4" == "n" ];then
|
||||
need_preprocess=$4
|
||||
else
|
||||
echo "weather need preprocess or not, it's value must be in [y, n]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
device_id=0
|
||||
if [ $# == 5 ]; then
|
||||
device_id=$5
|
||||
fi
|
||||
|
||||
echo "mindir name: "$model
|
||||
echo "dataset path: "$dataset_path
|
||||
echo "det path: "$det_path
|
||||
echo "need preprocess: "$need_preprocess
|
||||
echo "device id: "$device_id
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
function preprocess_data()
|
||||
{
|
||||
if [ -d preprocess_Result ]; then
|
||||
rm -rf ./preprocess_Result
|
||||
fi
|
||||
mkdir preprocess_Result
|
||||
python3.7 ../preprocess.py --data_path=$dataset_path --det_path=$det_path --result_path=./preprocess_Result/ &>preprocess.log
|
||||
}
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
cd ../ascend310_infer || exit
|
||||
bash build.sh &> build.log
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
cd - || exit
|
||||
if [ -d result_Files ]; then
|
||||
rm -rf ./result_Files
|
||||
fi
|
||||
if [ -d time_Result ]; then
|
||||
rm -rf ./time_Result
|
||||
fi
|
||||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
|
||||
../ascend310_infer/out/main --mindir_path=$model --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
|
||||
|
||||
}
|
||||
|
||||
function generater_detection()
|
||||
{
|
||||
python3.7 ../postprocess.py --det_dir=$det_path --result_dir=./result_Files --output_dir=../detections/ &> detection.log
|
||||
}
|
||||
|
||||
if [ $need_preprocess == "y" ]; then
|
||||
preprocess_data
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "preprocess dataset failed"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
compile_app
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed"
|
||||
exit 1
|
||||
fi
|
||||
infer
|
||||
if [ $? -ne 0 ]; then
|
||||
echo " execute inference failed"
|
||||
exit 1
|
||||
fi
|
||||
generater_detection
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "generator detection failed"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import deep_sort_app
|
||||
|
||||
from src.sort.iou_matching import iou
|
||||
from src.application_util import visualization
|
||||
|
||||
|
||||
DEFAULT_UPDATE_MS = 20
|
||||
|
||||
|
||||
def run(sequence_dir, result_file, show_false_alarms=False, detection_file=None,
|
||||
update_ms=None, video_filename=None):
|
||||
"""Run tracking result visualization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sequence_dir : str
|
||||
Path to the MOTChallenge sequence directory.
|
||||
result_file : str
|
||||
Path to the tracking output file in MOTChallenge ground truth format.
|
||||
show_false_alarms : Optional[bool]
|
||||
If True, false alarms are highlighted as red boxes.
|
||||
detection_file : Optional[str]
|
||||
Path to the detection file.
|
||||
update_ms : Optional[int]
|
||||
Number of milliseconds between cosecutive frames. Defaults to (a) the
|
||||
frame rate specified in the seqinfo.ini file or DEFAULT_UDPATE_MS ms if
|
||||
seqinfo.ini is not available.
|
||||
video_filename : Optional[Str]
|
||||
If not None, a video of the tracking results is written to this file.
|
||||
|
||||
"""
|
||||
seq_info = deep_sort_app.gather_sequence_info(sequence_dir, detection_file)
|
||||
results = np.loadtxt(result_file, delimiter=',')
|
||||
|
||||
if show_false_alarms and seq_info["groundtruth"] is None:
|
||||
raise ValueError("No groundtruth available. Cannot show false alarms.")
|
||||
|
||||
def frame_callback(vis, frame_idx):
|
||||
print("Frame idx", frame_idx)
|
||||
image = cv2.imread(
|
||||
seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)
|
||||
|
||||
vis.set_image(image.copy())
|
||||
|
||||
if seq_info["detections"] is not None:
|
||||
detections = deep_sort_app.create_detections(
|
||||
seq_info["detections"], frame_idx)
|
||||
vis.draw_detections(detections)
|
||||
|
||||
mask = results[:, 0].astype(np.int) == frame_idx
|
||||
track_ids = results[mask, 1].astype(np.int)
|
||||
boxes = results[mask, 2:6]
|
||||
vis.draw_groundtruth(track_ids, boxes)
|
||||
|
||||
if show_false_alarms:
|
||||
groundtruth = seq_info["groundtruth"]
|
||||
mask = groundtruth[:, 0].astype(np.int) == frame_idx
|
||||
gt_boxes = groundtruth[mask, 2:6]
|
||||
for box in boxes:
|
||||
# NOTE(nwojke): This is not strictly correct, because we don't
|
||||
# solve the assignment problem here.
|
||||
min_iou_overlap = 0.5
|
||||
if iou(box, gt_boxes).max() < min_iou_overlap:
|
||||
vis.viewer.color = 0, 0, 255
|
||||
vis.viewer.thickness = 4
|
||||
vis.viewer.rectangle(*box.astype(np.int))
|
||||
|
||||
if update_ms is None:
|
||||
update_ms = seq_info["update_ms"]
|
||||
if update_ms is None:
|
||||
update_ms = DEFAULT_UPDATE_MS
|
||||
visualizer = visualization.Visualization(seq_info, update_ms)
|
||||
if video_filename is not None:
|
||||
visualizer.viewer.enable_videowriter(video_filename)
|
||||
visualizer.run(frame_callback)
|
||||
|
||||
|
||||
def parse_args():
|
||||
""" Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Siamese Tracking")
|
||||
parser.add_argument(
|
||||
"--sequence_dir", help="Path to the MOTChallenge sequence directory.",
|
||||
default="../MOT16/train")
|
||||
parser.add_argument(
|
||||
"--result_file", help="Tracking output in MOTChallenge file format.",
|
||||
default="./results/MOT16-01.txt")
|
||||
parser.add_argument(
|
||||
"--detection_file", help="Path to custom detections (optional).",
|
||||
default="../resources/detections/MOT16_POI_test/MOT16-01.npy")
|
||||
parser.add_argument(
|
||||
"--update_ms", help="Time between consecutive frames in milliseconds. "
|
||||
"Defaults to the frame_rate specified in seqinfo.ini, if available.",
|
||||
default=None)
|
||||
parser.add_argument(
|
||||
"--output_file", help="Filename of the (optional) output video.",
|
||||
default=None)
|
||||
parser.add_argument(
|
||||
"--show_false_alarms", help="Show false alarms as red bounding boxes.",
|
||||
type=bool, default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
run(
|
||||
args.sequence_dir, args.result_file, args.show_false_alarms,
|
||||
args.detection_file, args.update_ms, args.output_file)
|
|
@ -1,356 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
This module contains an image viewer and drawing routines based on OpenCV.
|
||||
"""
|
||||
import time
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def is_in_bounds(mat, roi):
|
||||
"""Check if ROI is fully contained in the image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mat : ndarray
|
||||
An ndarray of ndim>=2.
|
||||
roi : (int, int, int, int)
|
||||
Region of interest (x, y, width, height) where (x, y) is the top-left
|
||||
corner.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
Returns true if the ROI is contain in mat.
|
||||
|
||||
"""
|
||||
if roi[0] < 0 or roi[0] + roi[2] >= mat.shape[1]:
|
||||
return False
|
||||
if roi[1] < 0 or roi[1] + roi[3] >= mat.shape[0]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def view_roi(mat, roi):
|
||||
"""Get sub-array.
|
||||
|
||||
The ROI must be valid, i.e., fully contained in the image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mat : ndarray
|
||||
An ndarray of ndim=2 or ndim=3.
|
||||
roi : (int, int, int, int)
|
||||
Region of interest (x, y, width, height) where (x, y) is the top-left
|
||||
corner.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
A view of the roi.
|
||||
|
||||
"""
|
||||
sx, ex = roi[0], roi[0] + roi[2]
|
||||
sy, ey = roi[1], roi[1] + roi[3]
|
||||
if mat.ndim == 2:
|
||||
return mat[sy:ey, sx:ex]
|
||||
return mat[sy:ey, sx:ex, :]
|
||||
|
||||
|
||||
class ImageViewer:
|
||||
"""An image viewer with drawing routines and video capture capabilities.
|
||||
|
||||
Key Bindings:
|
||||
|
||||
* 'SPACE' : pause
|
||||
* 'ESC' : quit
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update_ms : int
|
||||
Number of milliseconds between frames (1000 / frames per second).
|
||||
window_shape : (int, int)
|
||||
Shape of the window (width, height).
|
||||
caption : Optional[str]
|
||||
Title of the window.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
image : ndarray
|
||||
Color image of shape (height, width, 3). You may directly manipulate
|
||||
this image to change the view. Otherwise, you may call any of the
|
||||
drawing routines of this class. Internally, the image is treated as
|
||||
being in BGR color space.
|
||||
|
||||
Note that the image is resized to the the image viewers window_shape
|
||||
just prior to visualization. Therefore, you may pass differently sized
|
||||
images and call drawing routines with the appropriate, original point
|
||||
coordinates.
|
||||
color : (int, int, int)
|
||||
Current BGR color code that applies to all drawing routines.
|
||||
Values are in range [0-255].
|
||||
text_color : (int, int, int)
|
||||
Current BGR text color code that applies to all text rendering
|
||||
routines. Values are in range [0-255].
|
||||
thickness : int
|
||||
Stroke width in pixels that applies to all drawing routines.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, update_ms, window_shape=(640, 480), caption="Figure 1"):
|
||||
self._window_shape = window_shape
|
||||
self._caption = caption
|
||||
self._update_ms = update_ms
|
||||
self._video_writer = None
|
||||
self._user_fun = lambda: None
|
||||
self._terminate = False
|
||||
|
||||
self.image = np.zeros(self._window_shape+(3,), dtype=np.uint8)
|
||||
self._color = (0, 0, 0)
|
||||
self.text_color = (255, 255, 255)
|
||||
self.thickness = 1
|
||||
|
||||
@property
|
||||
def color(self):
|
||||
return self._color
|
||||
|
||||
@color.setter
|
||||
def color(self, value):
|
||||
if len(value) != 3:
|
||||
raise ValueError("color must be tuple of 3")
|
||||
self._color = tuple(int(c) for c in value)
|
||||
|
||||
def rectangle(self, x, y, w, h, label=None):
|
||||
"""Draw a rectangle.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float | int
|
||||
Top left corner of the rectangle (x-axis).
|
||||
y : float | int
|
||||
Top let corner of the rectangle (y-axis).
|
||||
w : float | int
|
||||
Width of the rectangle.
|
||||
h : float | int
|
||||
Height of the rectangle.
|
||||
label : Optional[str]
|
||||
A text label that is placed at the top left corner of the
|
||||
rectangle.
|
||||
|
||||
"""
|
||||
pt1 = int(x), int(y)
|
||||
pt2 = int(x + w), int(y + h)
|
||||
cv2.rectangle(self.image, pt1, pt2, self._color, self.thickness)
|
||||
if label is not None:
|
||||
text_size = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_PLAIN, 1, self.thickness)
|
||||
|
||||
center = pt1[0] + 5, pt1[1] + 5 + text_size[0][1]
|
||||
pt2 = pt1[0] + 10 + text_size[0][0], pt1[1] + 10 + \
|
||||
text_size[0][1]
|
||||
cv2.rectangle(self.image, pt1, pt2, self._color, -1)
|
||||
cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
|
||||
1, (255, 255, 255), self.thickness)
|
||||
|
||||
def circle(self, x, y, radius, label=None):
|
||||
"""Draw a circle.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float | int
|
||||
Center of the circle (x-axis).
|
||||
y : float | int
|
||||
Center of the circle (y-axis).
|
||||
radius : float | int
|
||||
Radius of the circle in pixels.
|
||||
label : Optional[str]
|
||||
A text label that is placed at the center of the circle.
|
||||
|
||||
"""
|
||||
image_size = int(radius + self.thickness + 1.5) # actually half size
|
||||
roi = int(x - image_size), int(y - image_size), \
|
||||
int(2 * image_size), int(2 * image_size)
|
||||
if not is_in_bounds(self.image, roi):
|
||||
return
|
||||
|
||||
image = view_roi(self.image, roi)
|
||||
center = image.shape[1] // 2, image.shape[0] // 2
|
||||
cv2.circle(
|
||||
image, center, int(radius + .5), self._color, self.thickness)
|
||||
if label is not None:
|
||||
cv2.putText(
|
||||
self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
|
||||
2, self.text_color, 2)
|
||||
|
||||
def gaussian(self, mean, covariance, label=None):
|
||||
"""Draw 95% confidence ellipse of a 2-D Gaussian distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : array_like
|
||||
The mean vector of the Gaussian distribution (ndim=1).
|
||||
covariance : array_like
|
||||
The 2x2 covariance matrix of the Gaussian distribution.
|
||||
label : Optional[str]
|
||||
A text label that is placed at the center of the ellipse.
|
||||
|
||||
"""
|
||||
# chi2inv(0.95, 2) = 5.9915
|
||||
vals, vecs = np.linalg.eigh(5.9915 * covariance)
|
||||
indices = vals.argsort()[::-1]
|
||||
vals, vecs = np.sqrt(vals[indices]), vecs[:, indices]
|
||||
|
||||
center = int(mean[0] + .5), int(mean[1] + .5)
|
||||
axes = int(vals[0] + .5), int(vals[1] + .5)
|
||||
angle = int(180. * np.arctan2(vecs[1, 0], vecs[0, 0]) / np.pi)
|
||||
cv2.ellipse(
|
||||
self.image, center, axes, angle, 0, 360, self._color, 2)
|
||||
if label is not None:
|
||||
cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
|
||||
2, self.text_color, 2)
|
||||
|
||||
def annotate(self, x, y, text):
|
||||
"""Draws a text string at a given location.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : int | float
|
||||
Bottom-left corner of the text in the image (x-axis).
|
||||
y : int | float
|
||||
Bottom-left corner of the text in the image (y-axis).
|
||||
text : str
|
||||
The text to be drawn.
|
||||
|
||||
"""
|
||||
cv2.putText(self.image, text, (int(x), int(y)), cv2.FONT_HERSHEY_PLAIN,
|
||||
2, self.text_color, 2)
|
||||
|
||||
def colored_points(self, points, colors=None, skip_index_check=False):
|
||||
"""Draw a collection of points.
|
||||
|
||||
The point size is fixed to 1.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
points : ndarray
|
||||
The Nx2 array of image locations, where the first dimension is
|
||||
the x-coordinate and the second dimension is the y-coordinate.
|
||||
colors : Optional[ndarray]
|
||||
The Nx3 array of colors (dtype=np.uint8). If None, the current
|
||||
color attribute is used.
|
||||
skip_index_check : Optional[bool]
|
||||
If True, index range checks are skipped. This is faster, but
|
||||
requires all points to lie within the image dimensions.
|
||||
|
||||
"""
|
||||
if not skip_index_check:
|
||||
cond1, cond2 = points[:, 0] >= 0, points[:, 0] < 480
|
||||
cond3, cond4 = points[:, 1] >= 0, points[:, 1] < 640
|
||||
indices = np.logical_and.reduce((cond1, cond2, cond3, cond4))
|
||||
points = points[indices, :]
|
||||
if colors is None:
|
||||
colors = np.repeat(
|
||||
self._color, len(points)).reshape(3, len(points)).T
|
||||
indices = (points + .5).astype(np.int)
|
||||
self.image[indices[:, 1], indices[:, 0], :] = colors
|
||||
|
||||
def enable_videowriter(self, output_filename, fourcc_string="MJPG",
|
||||
fps=None):
|
||||
""" Write images to video file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output_filename : str
|
||||
Output filename.
|
||||
fourcc_string : str
|
||||
The OpenCV FOURCC code that defines the video codec (check OpenCV
|
||||
documentation for more information).
|
||||
fps : Optional[float]
|
||||
Frames per second. If None, configured according to current
|
||||
parameters.
|
||||
|
||||
"""
|
||||
fourcc = cv2.VideoWriter_fourcc(*fourcc_string)
|
||||
if fps is None:
|
||||
fps = int(1000. / self._update_ms)
|
||||
self._video_writer = cv2.VideoWriter(
|
||||
output_filename, fourcc, fps, self._window_shape)
|
||||
|
||||
def disable_videowriter(self):
|
||||
""" Disable writing videos.
|
||||
"""
|
||||
self._video_writer = None
|
||||
|
||||
def run(self, update_fun=None):
|
||||
"""Start the image viewer.
|
||||
|
||||
This method blocks until the user requests to close the window.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update_fun : Optional[Callable[] -> None]
|
||||
An optional callable that is invoked at each frame. May be used
|
||||
to play an animation/a video sequence.
|
||||
|
||||
"""
|
||||
if update_fun is not None:
|
||||
self._user_fun = update_fun
|
||||
|
||||
self._terminate, is_paused = False, False
|
||||
# print("ImageViewer is paused, press space to start.")
|
||||
while not self._terminate:
|
||||
t0 = time.time()
|
||||
if not is_paused:
|
||||
self._terminate = not self._user_fun()
|
||||
if self._video_writer is not None:
|
||||
self._video_writer.write(
|
||||
cv2.resize(self.image, self._window_shape))
|
||||
t1 = time.time()
|
||||
remaining_time = max(1, int(self._update_ms - 1e3*(t1-t0)))
|
||||
cv2.imshow(
|
||||
self._caption, cv2.resize(self.image, self._window_shape[:2]))
|
||||
key = cv2.waitKey(remaining_time)
|
||||
if key & 255 == 27: # ESC
|
||||
print("terminating")
|
||||
self._terminate = True
|
||||
elif key & 255 == 32: # ' '
|
||||
print("toggeling pause: " + str(not is_paused))
|
||||
is_paused = not is_paused
|
||||
elif key & 255 == 115: # 's'
|
||||
print("stepping")
|
||||
self._terminate = not self._user_fun()
|
||||
is_paused = True
|
||||
|
||||
# Due to a bug in OpenCV we must call imshow after destroying the
|
||||
# window. This will make the window appear again as soon as waitKey
|
||||
# is called.
|
||||
#
|
||||
# see https://github.com/Itseez/opencv/issues/4535
|
||||
self.image[:] = 0
|
||||
cv2.destroyWindow(self._caption)
|
||||
cv2.waitKey(1)
|
||||
cv2.imshow(self._caption, self.image)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the control loop.
|
||||
|
||||
After calling this method, the viewer will stop execution before the
|
||||
next frame and hand over control flow to the user.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
"""
|
||||
self._terminate = True
|
|
@ -1,85 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
def non_max_suppression(boxes, max_bbox_overlap, scores=None):
|
||||
"""Suppress overlapping detections.
|
||||
|
||||
Original code from [1]_ has been adapted to include confidence score.
|
||||
|
||||
.. [1] http://www.pyimagesearch.com/2015/02/16/
|
||||
faster-non-maximum-suppression-python/
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> boxes = [d.roi for d in detections]
|
||||
>>> scores = [d.confidence for d in detections]
|
||||
>>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
|
||||
>>> detections = [detections[i] for i in indices]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
boxes : ndarray
|
||||
Array of ROIs (x, y, width, height).
|
||||
max_bbox_overlap : float
|
||||
ROIs that overlap more than this values are suppressed.
|
||||
scores : Optional[array_like]
|
||||
Detector confidence score.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[int]
|
||||
Returns indices of detections that have survived non-maxima suppression.
|
||||
|
||||
"""
|
||||
if np.size(boxes) == 0:
|
||||
return []
|
||||
|
||||
boxes = boxes.astype(np.float)
|
||||
pick = []
|
||||
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2] + boxes[:, 0]
|
||||
y2 = boxes[:, 3] + boxes[:, 1]
|
||||
|
||||
area = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
if scores is not None:
|
||||
idxs = np.argsort(scores)
|
||||
else:
|
||||
idxs = np.argsort(y2)
|
||||
|
||||
while np.size(idxs) > 0:
|
||||
last = len(idxs) - 1
|
||||
i = idxs[last]
|
||||
pick.append(i)
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[idxs[:last]])
|
||||
yy1 = np.maximum(y1[i], y1[idxs[:last]])
|
||||
xx2 = np.minimum(x2[i], x2[idxs[:last]])
|
||||
yy2 = np.minimum(y2[i], y2[idxs[:last]])
|
||||
|
||||
w = np.maximum(0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0, yy2 - yy1 + 1)
|
||||
|
||||
overlap = (w * h) / area[idxs[:last]]
|
||||
|
||||
idxs = np.delete(
|
||||
idxs, np.concatenate(
|
||||
([last], np.where(overlap > max_bbox_overlap)[0])))
|
||||
|
||||
return pick
|
|
@ -1,144 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import colorsys
|
||||
import numpy as np
|
||||
from .image_viewer import ImageViewer
|
||||
|
||||
|
||||
def create_unique_color_float(tag, hue_step=0.41):
|
||||
"""Create a unique RGB color code for a given track id (tag).
|
||||
|
||||
The color code is generated in HSV color space by moving along the
|
||||
hue angle and gradually changing the saturation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tag : int
|
||||
The unique target identifying tag.
|
||||
hue_step : float
|
||||
Difference between two neighboring color codes in HSV space (more
|
||||
specifically, the distance in hue channel).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(float, float, float)
|
||||
RGB color code in range [0, 1]
|
||||
|
||||
"""
|
||||
h, v = (tag * hue_step) % 1, 1. - (int(tag * hue_step) % 4) / 5.
|
||||
r, g, b = colorsys.hsv_to_rgb(h, 1., v)
|
||||
return r, g, b
|
||||
|
||||
|
||||
def create_unique_color_uchar(tag, hue_step=0.41):
|
||||
"""Create a unique RGB color code for a given track id (tag).
|
||||
|
||||
The color code is generated in HSV color space by moving along the
|
||||
hue angle and gradually changing the saturation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tag : int
|
||||
The unique target identifying tag.
|
||||
hue_step : float
|
||||
Difference between two neighboring color codes in HSV space (more
|
||||
specifically, the distance in hue channel).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(int, int, int)
|
||||
RGB color code in range [0, 255]
|
||||
|
||||
"""
|
||||
r, g, b = create_unique_color_float(tag, hue_step)
|
||||
return int(255*r), int(255*g), int(255*b)
|
||||
|
||||
|
||||
class NoVisualization:
|
||||
"""
|
||||
A dummy visualization object that loops through all frames in a given
|
||||
sequence to update the tracker without performing any visualization.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_info):
|
||||
self.frame_idx = seq_info["min_frame_idx"]
|
||||
self.last_idx = seq_info["max_frame_idx"]
|
||||
|
||||
def set_image(self, image):
|
||||
pass
|
||||
|
||||
def draw_groundtruth(self, track_ids, boxes):
|
||||
pass
|
||||
|
||||
def draw_detections(self, detections):
|
||||
pass
|
||||
|
||||
def draw_trackers(self, trackers):
|
||||
pass
|
||||
|
||||
def run(self, frame_callback):
|
||||
while self.frame_idx <= self.last_idx:
|
||||
frame_callback(self, self.frame_idx)
|
||||
self.frame_idx += 1
|
||||
|
||||
|
||||
class Visualization:
|
||||
"""
|
||||
This class shows tracking output in an OpenCV image viewer.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_info, update_ms):
|
||||
image_shape = seq_info["image_size"][::-1]
|
||||
aspect_ratio = float(image_shape[1]) / image_shape[0]
|
||||
image_shape = 1024, int(aspect_ratio * 1024)
|
||||
self.viewer = ImageViewer(
|
||||
update_ms, image_shape, "Figure %s" % seq_info["sequence_name"])
|
||||
self.viewer.thickness = 2
|
||||
self.frame_idx = seq_info["min_frame_idx"]
|
||||
self.last_idx = seq_info["max_frame_idx"]
|
||||
|
||||
def run(self, frame_callback):
|
||||
self.viewer.run(lambda: self._update_fun(frame_callback))
|
||||
|
||||
def _update_fun(self, frame_callback):
|
||||
if self.frame_idx > self.last_idx:
|
||||
return False # Terminate
|
||||
frame_callback(self, self.frame_idx)
|
||||
self.frame_idx += 1
|
||||
return True
|
||||
|
||||
def set_image(self, image):
|
||||
self.viewer.image = image
|
||||
|
||||
def draw_groundtruth(self, track_ids, boxes):
|
||||
self.viewer.thickness = 2
|
||||
for track_id, box in zip(track_ids, boxes):
|
||||
self.viewer.color = create_unique_color_uchar(track_id)
|
||||
self.viewer.rectangle(*box.astype(np.int), label=str(track_id))
|
||||
|
||||
def draw_detections(self, detections):
|
||||
self.viewer.thickness = 2
|
||||
self.viewer.color = 0, 0, 255
|
||||
for detection in detections:
|
||||
self.viewer.rectangle(*detection.tlwh)
|
||||
|
||||
def draw_trackers(self, tracks):
|
||||
self.viewer.thickness = 2
|
||||
for track in tracks:
|
||||
if not track.is_confirmed() or track.time_since_update > 0:
|
||||
continue
|
||||
self.viewer.color = create_unique_color_uchar(track.track_id)
|
||||
self.viewer.rectangle(
|
||||
*track.to_tlwh().astype(np.int), label=str(track.track_id))
|
|
@ -1,75 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import cv2
|
||||
import mindspore
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from .original_model import Net
|
||||
|
||||
class Extractor:
|
||||
def __init__(self, model_path, batch_size=32):
|
||||
self.net = Net(reid=True)
|
||||
self.batch_size = batch_size
|
||||
param_dict = load_checkpoint(model_path)
|
||||
load_param_into_net(self.net, param_dict)
|
||||
self.size = (64, 128)
|
||||
|
||||
def statistic_normalize_img(self, img, statistic_norm=True):
|
||||
"""Statistic normalize images."""
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
if statistic_norm:
|
||||
img = (img - mean) / std
|
||||
img = img.astype(np.float32)
|
||||
return img
|
||||
|
||||
def _preprocess(self, im_crops):
|
||||
"""
|
||||
TODO:
|
||||
1. to float with scale from 0 to 1
|
||||
2. resize to (64, 128) as Market1501 dataset did
|
||||
3. concatenate to a numpy array
|
||||
3. to torch Tensor
|
||||
4. normalize
|
||||
"""
|
||||
def _resize(im, size):
|
||||
return cv2.resize(im.astype(np.float32)/255., size)
|
||||
im_batch = []
|
||||
for im in im_crops:
|
||||
im = _resize(im, self.size)
|
||||
im = self.statistic_normalize_img(im)
|
||||
im = mindspore.Tensor.from_numpy(im.transpose(2, 0, 1).copy())
|
||||
im = mindspore.ops.ExpandDims()(im, 0)
|
||||
im_batch.append(im)
|
||||
|
||||
im_batch = mindspore.ops.Concat(axis=0)(tuple(im_batch))
|
||||
return im_batch
|
||||
|
||||
|
||||
def __call__(self, im_crops):
|
||||
out = np.zeros((len(im_crops), 128), np.float32)
|
||||
num_batches = int(len(im_crops)/self.batch_size)
|
||||
s, e = 0, 0
|
||||
for i in range(num_batches):
|
||||
s, e = i * self.batch_size, (i + 1) * self.batch_size
|
||||
im_batch = self._preprocess(im_crops[s:e])
|
||||
feature = self.net(im_batch)
|
||||
out[s:e] = feature.asnumpy()
|
||||
if e < len(out):
|
||||
im_batch = self._preprocess(im_crops[e:])
|
||||
feature = self.net(im_batch)
|
||||
out[e:] = feature.asnumpy()
|
||||
return out
|
|
@ -1,124 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as F
|
||||
|
||||
class BasicBlock(nn.Cell):
|
||||
def __init__(self, c_in, c_out, is_downsample=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.add = mindspore.ops.Add()
|
||||
self.ReLU = F.ReLU()
|
||||
self.is_downsample = is_downsample
|
||||
if is_downsample:
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, pad_mode='pad', padding=1,\
|
||||
has_bias=False, weight_init='HeUniform')
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, pad_mode='same',\
|
||||
has_bias=False, weight_init='HeUniform')
|
||||
self.bn1 = nn.BatchNorm2d(c_out, momentum=0.9)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, pad_mode='pad', padding=1,\
|
||||
has_bias=False, weight_init='HeUniform')
|
||||
self.bn2 = nn.BatchNorm2d(c_out, momentum=0.9)
|
||||
if is_downsample:
|
||||
self.downsample = nn.SequentialCell(
|
||||
[nn.Conv2d(c_in, c_out, 1, stride=2, pad_mode='same', has_bias=False, weight_init='HeUniform'),
|
||||
nn.BatchNorm2d(c_out, momentum=0.9)]
|
||||
)
|
||||
elif c_in != c_out:
|
||||
self.downsample = nn.SequentialCell(
|
||||
[nn.Conv2d(c_in, c_out, 1, stride=1, pad_mode='pad', has_bias=False, weight_init='HeUniform'),
|
||||
nn.BatchNorm2d(c_out, momentum=0.9)]
|
||||
)
|
||||
self.is_downsample = True
|
||||
def construct(self, x):
|
||||
y = self.conv1(x)
|
||||
y = self.bn1(y)
|
||||
y = self.relu(y)
|
||||
y = self.conv2(y)
|
||||
y = self.bn2(y)
|
||||
if self.is_downsample:
|
||||
x = self.downsample(x)
|
||||
y = self.add(x, y)
|
||||
y = self.ReLU(y)
|
||||
return y
|
||||
|
||||
def make_layers(c_in, c_out, repeat_times, is_downsample=False):
|
||||
blocks = []
|
||||
for i in range(repeat_times):
|
||||
if i == 0:
|
||||
blocks.append(BasicBlock(c_in, c_out, is_downsample=is_downsample))
|
||||
else:
|
||||
blocks.append(BasicBlock(c_out, c_out))
|
||||
return nn.SequentialCell(blocks)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, num_classes=751, reid=False, ascend=False):
|
||||
super(Net, self).__init__()
|
||||
# 3 128 64
|
||||
self.conv = nn.SequentialCell(
|
||||
[nn.Conv2d(3, 32, 3, stride=1, pad_mode='same', has_bias=True, weight_init='HeUniform'),
|
||||
nn.BatchNorm2d(32, momentum=0.9),
|
||||
nn.ELU(),
|
||||
nn.Conv2d(32, 32, 3, stride=1, pad_mode='same', has_bias=True, weight_init='HeUniform'),
|
||||
nn.BatchNorm2d(32, momentum=0.9),
|
||||
nn.ELU(),
|
||||
nn.MaxPool2d(3, 2, pad_mode='same')]
|
||||
)
|
||||
#]
|
||||
# 32 64 32
|
||||
self.layer1 = make_layers(32, 32, 2, False)
|
||||
# 32 64 32
|
||||
self.layer2 = make_layers(32, 64, 2, True)
|
||||
# 64 32 16
|
||||
self.layer3 = make_layers(64, 128, 2, True)
|
||||
# 128 16 8
|
||||
self.dp = nn.Dropout(keep_prob=0.6)
|
||||
self.dense = nn.Dense(128*16*8, 128)
|
||||
self.bn1 = nn.BatchNorm1d(128, momentum=0.9)
|
||||
self.elu = nn.ELU()
|
||||
# 256 1 1
|
||||
self.reid = reid
|
||||
self.ascend = ascend
|
||||
#self.flatten = nn.Flatten()
|
||||
self.div = F.Div()
|
||||
self.batch_norm = nn.BatchNorm1d(128, momentum=0.9)
|
||||
self.classifier = nn.Dense(128, num_classes)
|
||||
self.Norm = nn.Norm(axis=0, keep_dims=True)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
#x = self.flatten(x)
|
||||
x = x.view((x.shape[0], -1))
|
||||
if self.reid:
|
||||
x = self.dp(x)
|
||||
x = self.dense(x)
|
||||
if self.ascend:
|
||||
x = self.bn1(x)
|
||||
else:
|
||||
f = self.Norm(x)
|
||||
x = self.div(x, f)
|
||||
return x
|
||||
x = self.dp(x)
|
||||
x = self.dense(x)
|
||||
x = self.bn1(x)
|
||||
x = self.elu(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
|
@ -1,153 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.nn as nn
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common import set_seed
|
||||
from original_model import Net
|
||||
set_seed(1234)
|
||||
def parse_args():
|
||||
""" Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument("--epoch", help="Path to custom detections.", type=int, default=100)
|
||||
parser.add_argument("--batch_size", help="Batch size for Training.", type=int, default=8)
|
||||
parser.add_argument("--num_parallel_workers", help="The number of parallel workers.", type=int, default=16)
|
||||
parser.add_argument("--pre_train", help='The ckpt file of model.', type=str, default=None)
|
||||
parser.add_argument("--save_check_point", help="Whether save the training resulting.", type=bool, default=True)
|
||||
|
||||
#learning rate
|
||||
parser.add_argument("--learning_rate", help="Learning rate.", type=float, default=0.1)
|
||||
parser.add_argument("--decay_epoch", help="decay epochs.", type=int, default=20)
|
||||
parser.add_argument('--gamma', type=float, default=0.10, help='learning rate decay.')
|
||||
parser.add_argument("--momentum", help="", type=float, default=0.9)
|
||||
|
||||
#run on where
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=True, help='Run distribute')
|
||||
|
||||
return parser.parse_args()
|
||||
def get_lr(base_lr, total_epochs, steps_per_epoch, step_size, gamma):
|
||||
lr_each_step = []
|
||||
for i in range(1, total_epochs+1):
|
||||
if i % step_size == 0:
|
||||
base_lr *= gamma
|
||||
for _ in range(steps_per_epoch):
|
||||
lr_each_step.append(base_lr)
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
if args.run_modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
args.batch_size = args.batch_size*int(8/device_num)
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = '/cache/data'
|
||||
local_train_url = '/cache/train'
|
||||
mox.file.copy_parallel(args.data_url, local_data_url)
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=device_num,\
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
DATA_DIR = local_data_url + '/'
|
||||
else:
|
||||
if args.run_distribute:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
args.batch_size = args.batch_size*int(8/device_num)
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,\
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
else:
|
||||
context.set_context(device_id=args.device_id)
|
||||
device_num = 1
|
||||
args.batch_size = args.batch_size*int(8/device_num)
|
||||
device_id = args.device_id
|
||||
DATA_DIR = args.data_url + '/'
|
||||
|
||||
data = ds.ImageFolderDataset(DATA_DIR, decode=True, shuffle=True,\
|
||||
num_parallel_workers=args.num_parallel_workers, num_shards=device_num, shard_id=device_id)
|
||||
|
||||
transform_img = [
|
||||
C.RandomCrop((128, 64), padding=4),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize([0.485*255, 0.456*255, 0.406*255], [0.229*255, 0.224*255, 0.225*255]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
num_classes = max(data.num_classes(), 0)
|
||||
|
||||
data = data.map(input_columns="image", operations=transform_img, num_parallel_workers=args.num_parallel_workers)
|
||||
data = data.batch(batch_size=args.batch_size)
|
||||
|
||||
data_size = data.get_dataset_size()
|
||||
|
||||
loss_cb = LossMonitor(data_size)
|
||||
time_cb = TimeMonitor(data_size=data_size)
|
||||
callbacks = [time_cb, loss_cb]
|
||||
|
||||
#save training results
|
||||
if args.save_check_point and (device_num == 1 or device_id == 0):
|
||||
|
||||
model_save_path = './ckpt_' + str(6) + '/'
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=data_size*args.epoch, keep_checkpoint_max=args.epoch)
|
||||
|
||||
if args.run_modelarts:
|
||||
ckpoint_cb = ModelCheckpoint(prefix='deepsort', directory=local_train_url, config=config_ck)
|
||||
else:
|
||||
ckpoint_cb = ModelCheckpoint(prefix='deepsort', directory=model_save_path, config=config_ck)
|
||||
callbacks += [ckpoint_cb]
|
||||
|
||||
#design learning rate
|
||||
lr = Tensor(get_lr(args.learning_rate, args.epoch, data_size, args.decay_epoch, args.gamma))
|
||||
# net definition
|
||||
net = Net(num_classes=num_classes)
|
||||
|
||||
# loss and optimizer
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum)
|
||||
#optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum, weight_decay=5e-4)
|
||||
#optimizer = mindspore.nn.Momentum(params = net.trainable_params(), learning_rate=lr, momentum=args.momentum)
|
||||
|
||||
#train
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer)
|
||||
|
||||
model.train(args.epoch, data, callbacks=callbacks, dataset_sink_mode=True)
|
||||
if args.run_modelarts:
|
||||
mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Detection:
|
||||
"""
|
||||
This class represents a bounding box detection in a single image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tlwh : array_like
|
||||
Bounding box in format `(x, y, w, h)`.
|
||||
confidence : float
|
||||
Detector confidence score.
|
||||
feature : array_like
|
||||
A feature vector that describes the object contained in this image.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
tlwh : ndarray
|
||||
Bounding box in format `(top left x, top left y, width, height)`.
|
||||
confidence : ndarray
|
||||
Detector confidence score.
|
||||
feature : ndarray | NoneType
|
||||
A feature vector that describes the object contained in this image.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, tlwh, confidence, feature):
|
||||
self.tlwh = np.asarray(tlwh, dtype=np.float)
|
||||
self.confidence = float(confidence)
|
||||
self.feature = np.asarray(feature, dtype=np.float32)
|
||||
|
||||
def to_tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
def to_xyah(self):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
|
@ -1,94 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
from . import linear_assignment
|
||||
|
||||
|
||||
def iou(bbox, candidates):
|
||||
"""Computer intersection over union.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox : ndarray
|
||||
A bounding box in format `(top left x, top left y, width, height)`.
|
||||
candidates : ndarray
|
||||
A matrix of candidate bounding boxes (one per row) in the same format
|
||||
as `bbox`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The intersection over union in [0, 1] between the `bbox` and each
|
||||
candidate. A higher score means a larger fraction of the `bbox` is
|
||||
occluded by the candidate.
|
||||
|
||||
"""
|
||||
bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
|
||||
candidates_tl = candidates[:, :2]
|
||||
candidates_br = candidates[:, :2] + candidates[:, 2:]
|
||||
|
||||
tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
|
||||
np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
|
||||
br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
|
||||
np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
|
||||
wh = np.maximum(0., br - tl)
|
||||
|
||||
area_intersection = wh.prod(axis=1)
|
||||
area_bbox = bbox[2:].prod()
|
||||
area_candidates = candidates[:, 2:].prod(axis=1)
|
||||
return area_intersection / (area_bbox + area_candidates - area_intersection)
|
||||
|
||||
|
||||
def iou_cost(tracks, detections, track_indices=None,
|
||||
detection_indices=None):
|
||||
"""An intersection over union distance metric.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tracks : List[deep_sort.track.Track]
|
||||
A list of tracks.
|
||||
detections : List[deep_sort.detection.Detection]
|
||||
A list of detections.
|
||||
track_indices : Optional[List[int]]
|
||||
A list of indices to tracks that should be matched. Defaults to
|
||||
all `tracks`.
|
||||
detection_indices : Optional[List[int]]
|
||||
A list of indices to detections that should be matched. Defaults
|
||||
to all `detections`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a cost matrix of shape
|
||||
len(track_indices), len(detection_indices) where entry (i, j) is
|
||||
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
|
||||
|
||||
"""
|
||||
if track_indices is None:
|
||||
track_indices = np.arange(len(tracks))
|
||||
if detection_indices is None:
|
||||
detection_indices = np.arange(len(detections))
|
||||
|
||||
cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
|
||||
for row, track_idx in enumerate(track_indices):
|
||||
if tracks[track_idx].time_since_update > 1:
|
||||
cost_matrix[row, :] = linear_assignment.INFTY_COST
|
||||
continue
|
||||
|
||||
bbox = tracks[track_idx].to_tlwh()
|
||||
candidates = np.asarray([detections[i].tlwh for i in detection_indices])
|
||||
cost_matrix[row, :] = 1. - iou(bbox, candidates)
|
||||
return cost_matrix
|
|
@ -1,237 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
|
||||
chi2inv95 = {
|
||||
1: 3.8415,
|
||||
2: 5.9915,
|
||||
3: 7.8147,
|
||||
4: 9.4877,
|
||||
5: 11.070,
|
||||
6: 12.592,
|
||||
7: 14.067,
|
||||
8: 15.507,
|
||||
9: 16.919}
|
||||
|
||||
|
||||
class KalmanFilter:
|
||||
"""
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, a, h, vx, vy, va, vh
|
||||
|
||||
contains the bounding box center position (x, y), aspect ratio a, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, a, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||
aspect ratio a, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
1e-2,
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
1e-5,
|
||||
10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-2,
|
||||
self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[3],
|
||||
1e-5,
|
||||
self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(self._motion_mat, mean)
|
||||
covariance = np.linalg.multi_dot((
|
||||
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-1,
|
||||
self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((
|
||||
self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||
is the center position, a the aspect ratio, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(
|
||||
projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve(
|
||||
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((
|
||||
kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements,
|
||||
only_position=False):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
d = measurements - mean
|
||||
z = scipy.linalg.solve_triangular(
|
||||
cholesky_factor, d.T, lower=True, check_finite=False,
|
||||
overwrite_b=True)
|
||||
squared_maha = np.sum(z * z, axis=0)
|
||||
return squared_maha
|
|
@ -1,205 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
# from sklearn.utils.linear_assignment_ import linear_assignment
|
||||
from scipy.optimize import linear_sum_assignment as linear_assignment
|
||||
from . import kalman_filter
|
||||
|
||||
|
||||
INFTY_COST = 1e+5
|
||||
|
||||
|
||||
def min_cost_matching(
|
||||
distance_metric, max_distance, tracks, detections, track_indices=None,
|
||||
detection_indices=None):
|
||||
"""Solve linear assignment problem.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
||||
The distance metric is given a list of tracks and detections as well as
|
||||
a list of N track indices and M detection indices. The metric should
|
||||
return the NxM dimensional cost matrix, where element (i, j) is the
|
||||
association cost between the i-th track in the given track indices and
|
||||
the j-th detection in the given detection_indices.
|
||||
max_distance : float
|
||||
Gating threshold. Associations with cost larger than this value are
|
||||
disregarded.
|
||||
tracks : List[track.Track]
|
||||
A list of predicted tracks at the current time step.
|
||||
detections : List[detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
track_indices : List[int]
|
||||
List of track indices that maps rows in `cost_matrix` to tracks in
|
||||
`tracks` (see description above).
|
||||
detection_indices : List[int]
|
||||
List of detection indices that maps columns in `cost_matrix` to
|
||||
detections in `detections` (see description above).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(List[(int, int)], List[int], List[int])
|
||||
Returns a tuple with the following three entries:
|
||||
* A list of matched track and detection indices.
|
||||
* A list of unmatched track indices.
|
||||
* A list of unmatched detection indices.
|
||||
|
||||
"""
|
||||
if track_indices is None:
|
||||
track_indices = np.arange(len(tracks))
|
||||
if detection_indices is None:
|
||||
detection_indices = np.arange(len(detections))
|
||||
|
||||
if not detection_indices or not track_indices:
|
||||
return [], track_indices, detection_indices # Nothing to match.
|
||||
|
||||
cost_matrix = distance_metric(
|
||||
tracks, detections, track_indices, detection_indices)
|
||||
cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
|
||||
|
||||
row_indices, col_indices = linear_assignment(cost_matrix)
|
||||
|
||||
matches, unmatched_tracks, unmatched_detections = [], [], []
|
||||
for col, detection_idx in enumerate(detection_indices):
|
||||
if col not in col_indices:
|
||||
unmatched_detections.append(detection_idx)
|
||||
for row, track_idx in enumerate(track_indices):
|
||||
if row not in row_indices:
|
||||
unmatched_tracks.append(track_idx)
|
||||
for row, col in zip(row_indices, col_indices):
|
||||
track_idx = track_indices[row]
|
||||
detection_idx = detection_indices[col]
|
||||
if cost_matrix[row, col] > max_distance:
|
||||
unmatched_tracks.append(track_idx)
|
||||
unmatched_detections.append(detection_idx)
|
||||
else:
|
||||
matches.append((track_idx, detection_idx))
|
||||
return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
|
||||
def matching_cascade(
|
||||
distance_metric, max_distance, cascade_depth, tracks, detections,
|
||||
track_indices=None, detection_indices=None):
|
||||
"""Run matching cascade.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
||||
The distance metric is given a list of tracks and detections as well as
|
||||
a list of N track indices and M detection indices. The metric should
|
||||
return the NxM dimensional cost matrix, where element (i, j) is the
|
||||
association cost between the i-th track in the given track indices and
|
||||
the j-th detection in the given detection indices.
|
||||
max_distance : float
|
||||
Gating threshold. Associations with cost larger than this value are
|
||||
disregarded.
|
||||
cascade_depth: int
|
||||
The cascade depth, should be se to the maximum track age.
|
||||
tracks : List[track.Track]
|
||||
A list of predicted tracks at the current time step.
|
||||
detections : List[detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
track_indices : Optional[List[int]]
|
||||
List of track indices that maps rows in `cost_matrix` to tracks in
|
||||
`tracks` (see description above). Defaults to all tracks.
|
||||
detection_indices : Optional[List[int]]
|
||||
List of detection indices that maps columns in `cost_matrix` to
|
||||
detections in `detections` (see description above). Defaults to all
|
||||
detections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(List[(int, int)], List[int], List[int])
|
||||
Returns a tuple with the following three entries:
|
||||
* A list of matched track and detection indices.
|
||||
* A list of unmatched track indices.
|
||||
* A list of unmatched detection indices.
|
||||
|
||||
"""
|
||||
if track_indices is None:
|
||||
track_indices = list(range(len(tracks)))
|
||||
if detection_indices is None:
|
||||
detection_indices = list(range(len(detections)))
|
||||
|
||||
unmatched_detections = detection_indices
|
||||
matches = []
|
||||
for level in range(cascade_depth):
|
||||
if not unmatched_detections: # No detections left
|
||||
break
|
||||
|
||||
track_indices_l = [
|
||||
k for k in track_indices
|
||||
if tracks[k].time_since_update == 1 + level
|
||||
]
|
||||
if not track_indices_l: # Nothing to match at this level
|
||||
continue
|
||||
|
||||
matches_l, _, unmatched_detections = \
|
||||
min_cost_matching(
|
||||
distance_metric, max_distance, tracks, detections,
|
||||
track_indices_l, unmatched_detections)
|
||||
matches += matches_l
|
||||
unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
|
||||
return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
|
||||
def gate_cost_matrix(
|
||||
kf, cost_matrix, tracks, detections, track_indices, detection_indices,
|
||||
gated_cost=INFTY_COST, only_position=False):
|
||||
"""Invalidate infeasible entries in cost matrix based on the state
|
||||
distributions obtained by Kalman filtering.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kf : The Kalman filter.
|
||||
cost_matrix : ndarray
|
||||
The NxM dimensional cost matrix, where N is the number of track indices
|
||||
and M is the number of detection indices, such that entry (i, j) is the
|
||||
association cost between `tracks[track_indices[i]]` and
|
||||
`detections[detection_indices[j]]`.
|
||||
tracks : List[track.Track]
|
||||
A list of predicted tracks at the current time step.
|
||||
detections : List[detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
track_indices : List[int]
|
||||
List of track indices that maps rows in `cost_matrix` to tracks in
|
||||
`tracks` (see description above).
|
||||
detection_indices : List[int]
|
||||
List of detection indices that maps columns in `cost_matrix` to
|
||||
detections in `detections` (see description above).
|
||||
gated_cost : Optional[float]
|
||||
Entries in the cost matrix corresponding to infeasible associations are
|
||||
set this value. Defaults to a very large value.
|
||||
only_position : Optional[bool]
|
||||
If True, only the x, y position of the state distribution is considered
|
||||
during gating. Defaults to False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns the modified cost matrix.
|
||||
|
||||
"""
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = kalman_filter.chi2inv95[gating_dim]
|
||||
measurements = np.asarray(
|
||||
[detections[i].to_xyah() for i in detection_indices])
|
||||
for row, track_idx in enumerate(track_indices):
|
||||
track = tracks[track_idx]
|
||||
gating_distance = kf.gating_distance(
|
||||
track.mean, track.covariance, measurements, only_position)
|
||||
cost_matrix[row, gating_distance > gating_threshold] = gated_cost
|
||||
return cost_matrix
|
|
@ -1,190 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _pdist(a, b):
|
||||
"""Compute pair-wise squared distance between points in `a` and `b`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : array_like
|
||||
An NxM matrix of N samples of dimensionality M.
|
||||
b : array_like
|
||||
An LxM matrix of L samples of dimensionality M.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
||||
contains the squared distance between `a[i]` and `b[j]`.
|
||||
|
||||
"""
|
||||
a, b = np.asarray(a), np.asarray(b)
|
||||
if np.size(a) == 0 or np.size(b) == 0:
|
||||
return np.zeros((len(a), len(b)))
|
||||
a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
|
||||
r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
|
||||
r2 = np.clip(r2, 0., float(np.inf))
|
||||
return r2
|
||||
|
||||
|
||||
def _cosine_distance(a, b, data_is_normalized=False):
|
||||
"""Compute pair-wise cosine distance between points in `a` and `b`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : array_like
|
||||
An NxM matrix of N samples of dimensionality M.
|
||||
b : array_like
|
||||
An LxM matrix of L samples of dimensionality M.
|
||||
data_is_normalized : Optional[bool]
|
||||
If True, assumes rows in a and b are unit length vectors.
|
||||
Otherwise, a and b are explicitly normalized to length 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
||||
contains the squared distance between `a[i]` and `b[j]`.
|
||||
|
||||
"""
|
||||
if not data_is_normalized:
|
||||
a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
|
||||
b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
|
||||
return 1. - np.dot(a, b.T)
|
||||
|
||||
|
||||
def _nn_euclidean_distance(x, y):
|
||||
""" Helper function for nearest neighbor distance metric (Euclidean).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : ndarray
|
||||
A matrix of N row-vectors (sample points).
|
||||
y : ndarray
|
||||
A matrix of M row-vectors (query points).
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
A vector of length M that contains for each entry in `y` the
|
||||
smallest Euclidean distance to a sample in `x`.
|
||||
|
||||
"""
|
||||
distances = _pdist(x, y)
|
||||
return np.maximum(0.0, distances.min(axis=0))
|
||||
|
||||
|
||||
def _nn_cosine_distance(x, y):
|
||||
""" Helper function for nearest neighbor distance metric (cosine).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : ndarray
|
||||
A matrix of N row-vectors (sample points).
|
||||
y : ndarray
|
||||
A matrix of M row-vectors (query points).
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
A vector of length M that contains for each entry in `y` the
|
||||
smallest cosine distance to a sample in `x`.
|
||||
|
||||
"""
|
||||
distances = _cosine_distance(x, y)
|
||||
return distances.min(axis=0)
|
||||
|
||||
|
||||
class NearestNeighborDistanceMetric:
|
||||
"""
|
||||
A nearest neighbor distance metric that, for each target, returns
|
||||
the closest distance to any sample that has been observed so far.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
Either "euclidean" or "cosine".
|
||||
matching_threshold: float
|
||||
The matching threshold. Samples with larger distance are considered an
|
||||
invalid match.
|
||||
budget : Optional[int]
|
||||
If not None, fix samples per class to at most this number. Removes
|
||||
the oldest samples when the budget is reached.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
samples : Dict[int -> List[ndarray]]
|
||||
A dictionary that maps from target identities to the list of samples
|
||||
that have been observed so far.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, metric, matching_threshold, budget=None):
|
||||
|
||||
|
||||
if metric == "euclidean":
|
||||
self._metric = _nn_euclidean_distance
|
||||
elif metric == "cosine":
|
||||
self._metric = _nn_cosine_distance
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid metric; must be either 'euclidean' or 'cosine'")
|
||||
self.matching_threshold = matching_threshold
|
||||
self.budget = budget
|
||||
self.samples = {}
|
||||
|
||||
def partial_fit(self, features, targets, active_targets):
|
||||
"""Update the distance metric with new data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : ndarray
|
||||
An NxM matrix of N features of dimensionality M.
|
||||
targets : ndarray
|
||||
An integer array of associated target identities.
|
||||
active_targets : List[int]
|
||||
A list of targets that are currently present in the scene.
|
||||
|
||||
"""
|
||||
for feature, target in zip(features, targets):
|
||||
self.samples.setdefault(target, []).append(feature)
|
||||
if self.budget is not None:
|
||||
self.samples[target] = self.samples[target][-int(self.budget):]
|
||||
self.samples = {k: self.samples[k] for k in active_targets}
|
||||
|
||||
def distance(self, features, targets):
|
||||
"""Compute distance between features and targets.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : ndarray
|
||||
An NxM matrix of N features of dimensionality M.
|
||||
targets : List[int]
|
||||
A list of targets to match the given `features` against.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a cost matrix of shape len(targets), len(features), where
|
||||
element (i, j) contains the closest squared distance between
|
||||
`targets[i]` and `features[j]`.
|
||||
|
||||
"""
|
||||
cost_matrix = np.zeros((len(targets), len(features)))
|
||||
for i, target in enumerate(targets):
|
||||
cost_matrix[i, :] = self._metric(self.samples[target], features)
|
||||
return cost_matrix
|
|
@ -1,178 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
class TrackState:
|
||||
"""
|
||||
Enumeration type for the single target track state. Newly created tracks are
|
||||
classified as `tentative` until enough evidence has been collected. Then,
|
||||
the track state is changed to `confirmed`. Tracks that are no longer alive
|
||||
are classified as `deleted` to mark them for removal from the set of active
|
||||
tracks.
|
||||
|
||||
"""
|
||||
|
||||
Tentative = 1
|
||||
Confirmed = 2
|
||||
Deleted = 3
|
||||
|
||||
|
||||
class Track:
|
||||
"""
|
||||
A single target track with state space `(x, y, a, h)` and associated
|
||||
velocities, where `(x, y)` is the center of the bounding box, `a` is the
|
||||
aspect ratio and `h` is the height.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector of the initial state distribution.
|
||||
covariance : ndarray
|
||||
Covariance matrix of the initial state distribution.
|
||||
track_id : int
|
||||
A unique track identifier.
|
||||
n_init : int
|
||||
Number of consecutive detections before the track is confirmed. The
|
||||
track state is set to `Deleted` if a miss occurs within the first
|
||||
`n_init` frames.
|
||||
max_age : int
|
||||
The maximum number of consecutive misses before the track state is
|
||||
set to `Deleted`.
|
||||
feature : Optional[ndarray]
|
||||
Feature vector of the detection this track originates from. If not None,
|
||||
this feature is added to the `features` cache.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector of the initial state distribution.
|
||||
covariance : ndarray
|
||||
Covariance matrix of the initial state distribution.
|
||||
track_id : int
|
||||
A unique track identifier.
|
||||
hits : int
|
||||
Total number of measurement updates.
|
||||
age : int
|
||||
Total number of frames since first occurrence.
|
||||
time_since_update : int
|
||||
Total number of frames since last measurement update.
|
||||
state : TrackState
|
||||
The current track state.
|
||||
features : List[ndarray]
|
||||
A cache of features. On each measurement update, the associated feature
|
||||
vector is added to this list.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, mean, covariance, track_id, n_init, max_age,
|
||||
feature=None):
|
||||
self.mean = mean
|
||||
self.covariance = covariance
|
||||
self.track_id = track_id
|
||||
self.hits = 1
|
||||
self.age = 1
|
||||
self.time_since_update = 0
|
||||
|
||||
self.state = TrackState.Tentative
|
||||
self.features = []
|
||||
if feature is not None:
|
||||
self.features.append(feature)
|
||||
|
||||
self._n_init = n_init
|
||||
self._max_age = max_age
|
||||
|
||||
def to_tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The bounding box.
|
||||
|
||||
"""
|
||||
ret = self.mean[:4].copy()
|
||||
ret[2] *= ret[3]
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
def to_tlbr(self):
|
||||
"""Get current position in bounding box format `(min x, miny, max x,
|
||||
max y)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The bounding box.
|
||||
|
||||
"""
|
||||
ret = self.to_tlwh()
|
||||
ret[2:] = ret[:2] + ret[2:]
|
||||
return ret
|
||||
|
||||
def predict(self, kf):
|
||||
"""Propagate the state distribution to the current time step using a
|
||||
Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kf : kalman_filter.KalmanFilter
|
||||
The Kalman filter.
|
||||
|
||||
"""
|
||||
self.mean, self.covariance = kf.predict(self.mean, self.covariance)
|
||||
self.age += 1
|
||||
self.time_since_update += 1
|
||||
|
||||
def update(self, kf, detection):
|
||||
"""Perform Kalman filter measurement update step and update the feature
|
||||
cache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kf : kalman_filter.KalmanFilter
|
||||
The Kalman filter.
|
||||
detection : Detection
|
||||
The associated detection.
|
||||
|
||||
"""
|
||||
self.mean, self.covariance = kf.update(
|
||||
self.mean, self.covariance, detection.to_xyah())
|
||||
self.features.append(detection.feature)
|
||||
|
||||
self.hits += 1
|
||||
self.time_since_update = 0
|
||||
if self.state == TrackState.Tentative and self.hits >= self._n_init:
|
||||
self.state = TrackState.Confirmed
|
||||
|
||||
def mark_missed(self):
|
||||
"""Mark this track as missed (no association at the current time step).
|
||||
"""
|
||||
if self.state == TrackState.Tentative:
|
||||
self.state = TrackState.Deleted
|
||||
elif self.time_since_update > self._max_age:
|
||||
self.state = TrackState.Deleted
|
||||
|
||||
def is_tentative(self):
|
||||
"""Returns True if this track is tentative (unconfirmed).
|
||||
"""
|
||||
return self.state == TrackState.Tentative
|
||||
|
||||
def is_confirmed(self):
|
||||
"""Returns True if this track is confirmed."""
|
||||
return self.state == TrackState.Confirmed
|
||||
|
||||
def is_deleted(self):
|
||||
"""Returns True if this track is dead and should be deleted."""
|
||||
return self.state == TrackState.Deleted
|
|
@ -1,152 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
from . import kalman_filter
|
||||
from . import linear_assignment
|
||||
from . import iou_matching
|
||||
from .track import Track
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""
|
||||
This is the multi-target tracker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : nn_matching.NearestNeighborDistanceMetric
|
||||
A distance metric for measurement-to-track association.
|
||||
max_age : int
|
||||
Maximum number of missed misses before a track is deleted.
|
||||
n_init : int
|
||||
Number of consecutive detections before the track is confirmed. The
|
||||
track state is set to `Deleted` if a miss occurs within the first
|
||||
`n_init` frames.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
metric : nn_matching.NearestNeighborDistanceMetric
|
||||
The distance metric used for measurement to track association.
|
||||
max_age : int
|
||||
Maximum number of missed misses before a track is deleted.
|
||||
n_init : int
|
||||
Number of frames that a track remains in initialization phase.
|
||||
kf : kalman_filter.KalmanFilter
|
||||
A Kalman filter to filter target trajectories in image space.
|
||||
tracks : List[Track]
|
||||
The list of active tracks at the current time step.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
|
||||
self.metric = metric
|
||||
self.max_iou_distance = max_iou_distance
|
||||
self.max_age = max_age
|
||||
self.n_init = n_init
|
||||
|
||||
self.kf = kalman_filter.KalmanFilter()
|
||||
self.tracks = []
|
||||
self._next_id = 1
|
||||
|
||||
def predict(self):
|
||||
"""Propagate track state distributions one time step forward.
|
||||
|
||||
This function should be called once every time step, before `update`.
|
||||
"""
|
||||
for track in self.tracks:
|
||||
track.predict(self.kf)
|
||||
|
||||
def update(self, detections):
|
||||
"""Perform measurement update and track management.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detections : List[deep_sort.detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
|
||||
"""
|
||||
# Run matching cascade.
|
||||
matches, unmatched_tracks, unmatched_detections = \
|
||||
self._match(detections)
|
||||
|
||||
# Update track set.
|
||||
for track_idx, detection_idx in matches:
|
||||
self.tracks[track_idx].update(
|
||||
self.kf, detections[detection_idx])
|
||||
for track_idx in unmatched_tracks:
|
||||
self.tracks[track_idx].mark_missed()
|
||||
for detection_idx in unmatched_detections:
|
||||
self._initiate_track(detections[detection_idx])
|
||||
self.tracks = [t for t in self.tracks if not t.is_deleted()]
|
||||
|
||||
# Update distance metric.
|
||||
active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
|
||||
features, targets = [], []
|
||||
for track in self.tracks:
|
||||
if not track.is_confirmed():
|
||||
continue
|
||||
features += track.features
|
||||
targets += [track.track_id for _ in track.features]
|
||||
track.features = []
|
||||
self.metric.partial_fit(
|
||||
np.asarray(features), np.asarray(targets), active_targets)
|
||||
|
||||
def _match(self, detections):
|
||||
|
||||
def gated_metric(tracks, dets, track_indices, detection_indices):
|
||||
features = np.array([dets[i].feature for i in detection_indices])
|
||||
targets = np.array([tracks[i].track_id for i in track_indices])
|
||||
cost_matrix = self.metric.distance(features, targets)
|
||||
cost_matrix = linear_assignment.gate_cost_matrix(
|
||||
self.kf, cost_matrix, tracks, dets, track_indices,
|
||||
detection_indices)
|
||||
|
||||
return cost_matrix
|
||||
|
||||
# Split track set into confirmed and unconfirmed tracks.
|
||||
confirmed_tracks = [
|
||||
i for i, t in enumerate(self.tracks) if t.is_confirmed()]
|
||||
unconfirmed_tracks = [
|
||||
i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
|
||||
|
||||
# Associate confirmed tracks using appearance features.
|
||||
matches_a, unmatched_tracks_a, unmatched_detections = \
|
||||
linear_assignment.matching_cascade(
|
||||
gated_metric, self.metric.matching_threshold, self.max_age,
|
||||
self.tracks, detections, confirmed_tracks)
|
||||
|
||||
# Associate remaining tracks together with unconfirmed tracks using IOU.
|
||||
iou_track_candidates = unconfirmed_tracks + [
|
||||
k for k in unmatched_tracks_a if
|
||||
self.tracks[k].time_since_update == 1]
|
||||
unmatched_tracks_a = [
|
||||
k for k in unmatched_tracks_a if
|
||||
self.tracks[k].time_since_update != 1]
|
||||
matches_b, unmatched_tracks_b, unmatched_detections = \
|
||||
linear_assignment.min_cost_matching(
|
||||
iou_matching.iou_cost, self.max_iou_distance, self.tracks,
|
||||
detections, iou_track_candidates, unmatched_detections)
|
||||
|
||||
matches = matches_a + matches_b
|
||||
unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
|
||||
return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
def _initiate_track(self, detection):
|
||||
mean, covariance = self.kf.initiate(detection.to_xyah())
|
||||
self.tracks.append(Track(
|
||||
mean, covariance, self._next_id, self.n_init, self.max_age,
|
||||
detection.feature))
|
||||
self._next_id += 1
|
|
@ -1,540 +0,0 @@
|
|||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
- [FCN 介绍](#fcn-介绍)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速开始](#快速开始)
|
||||
- [脚本介绍](#脚本介绍)
|
||||
- [脚本以及简单代码](#脚本以及简单代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [生成数据步骤](#生成数据步骤)
|
||||
- [训练数据](#训练数据)
|
||||
- [训练步骤](#训练步骤)
|
||||
- [训练](#训练)
|
||||
- [评估步骤](#评估步骤)
|
||||
- [评估](#评估)
|
||||
- [导出过程](#导出过程)
|
||||
- [导出](#导出)
|
||||
- [推理过程](#推理过程)
|
||||
- [推理](#推理)
|
||||
- [模型介绍](#模型介绍)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [FCN8s on PASCAL VOC 2012](#fcn8s-on-pascal-voc-2012)
|
||||
- [Inference Performance](#inference-performance)
|
||||
- [FCN8s on PASCAL VOC](#fcn8s-on-pascal-voc)
|
||||
- [如何使用](#如何使用)
|
||||
- [教程](#教程)
|
||||
- [Set context](#set-context)
|
||||
- [Load dataset](#load-dataset)
|
||||
- [Define model](#define-model)
|
||||
- [optimizer](#optimizer)
|
||||
- [loss scale](#loss-scale)
|
||||
- [callback for saving ckpts](#callback-for-saving-ckpts)
|
||||
- [随机事件介绍](#随机事件介绍)
|
||||
- [ModelZoo 主页](#modelzoo-主页)
|
||||
|
||||
# [FCN 介绍](#contents)
|
||||
|
||||
FCN主要用用于图像分割领域,是一种端到端的分割方法。FCN丢弃了全连接层,使得其能够处理任意大小的图像,且减少了模型的参数量,提高了模型的分割速度。FCN在编码部分使用了VGG的结构,在解码部分中使用反卷积/上采样操作恢复图像的分辨率。FCN-8s最后使用8倍的反卷积/上采样操作将输出分割图恢复到与输入图像相同大小。
|
||||
|
||||
[Paper]: Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
|
||||
|
||||
# [模型架构](#contents)
|
||||
|
||||
FCN-8s使用丢弃全连接操作的VGG16作为编码部分,并分别融合VGG16中第3,4,5个池化层特征,最后使用stride=8的反卷积获得分割图像。
|
||||
|
||||
# [数据集](#contents)
|
||||
|
||||
Dataset used:
|
||||
|
||||
[PASCAL VOC 2012](<http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html>)
|
||||
|
||||
[SBD](<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz>)
|
||||
|
||||
# [环境要求](#contents)
|
||||
|
||||
- 硬件(Ascend/GPU)
|
||||
- 需要准备具有Ascend或GPU处理能力的硬件环境.
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需获取更多信息,请查看如下链接:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
||||
# [快速开始](#contents)
|
||||
|
||||
在通过官方网站安装MindSpore之后,你可以通过如下步骤开始训练以及评估:
|
||||
|
||||
```backbone
|
||||
vgg16训练ImageNet数据集的ckpt文件做为FCN8s的backbone
|
||||
vgg16网络路径: model_zoo/official/cv/vgg16
|
||||
```
|
||||
|
||||
```default_config.yaml
|
||||
data_file: /home/DataSet/voc2012/vocaug_mindrecords/vocaug.mindrecord0
|
||||
ckpt_vgg16: /home/DataSet/predtrained/vgg16_predtrained.ckpt
|
||||
data_root: /home/DataSet/voc2012/VOCdevkit/VOC2012
|
||||
data_lst: /home/DataSet/voc2012/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt
|
||||
ckpt_file: /home/FCN8s/ckpt/FCN8s_1-133_300.ckpt
|
||||
|
||||
根据本地数据存放路径修改参数
|
||||
```
|
||||
|
||||
- running on Ascend with default parameters
|
||||
|
||||
```python
|
||||
# Ascend单卡训练示例
|
||||
python train.py --device_id device_id
|
||||
|
||||
# Ascend评估示例
|
||||
python eval.py --device_id device_id
|
||||
```
|
||||
|
||||
- running on GPU with gpu default parameters
|
||||
|
||||
```python
|
||||
# GPU单卡训练示例
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
|
||||
# GPU多卡训练示例
|
||||
export RANK_SIZE=8
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
|
||||
# GPU评估示例
|
||||
python eval.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
```
|
||||
|
||||
# [脚本介绍](#contents)
|
||||
|
||||
## [脚本以及简单代码](#contents)
|
||||
|
||||
```python
|
||||
├── model_zoo
|
||||
├── README.md // descriptions about all the models
|
||||
├── FCN8s
|
||||
├── README.md // descriptions about FCN
|
||||
├── ascend310_infer // 实现310推理源代码
|
||||
├── scripts
|
||||
├── run_train.sh
|
||||
├── run_standalone_train.sh
|
||||
├── run_standalone_train_gpu.sh // train in gpu with single device
|
||||
├── run_distribute_train_gpu.sh // train in gpu with multi device
|
||||
├── run_eval.sh
|
||||
├── run_infer_310.sh // Ascend推理shell脚本
|
||||
├── build_data.sh
|
||||
├── src
|
||||
│ ├──data
|
||||
│ ├──build_seg_data.py // creating dataset
|
||||
│ ├──dataset.py // loading dataset
|
||||
│ ├──nets
|
||||
│ ├──FCN8s.py // FCN-8s architecture
|
||||
│ ├──loss
|
||||
│ ├──loss.py // loss function
|
||||
│ ├──utils
|
||||
│ ├──lr_scheduler.py // getting learning_rateFCN-8s
|
||||
│ ├──model_utils
|
||||
│ ├──config.py // getting config parameters
|
||||
│ ├──device_adapter.py // getting device info
|
||||
│ ├──local_adapter.py // getting device info
|
||||
│ ├──moxing_adapter.py // Decorator
|
||||
├── default_config.yaml // Ascend parameters config
|
||||
├── gpu_default_config.yaml // GPU parameters config
|
||||
├── train.py // training script
|
||||
├── postprogress.py // 310推理后处理脚本
|
||||
├── export.py // 将checkpoint文件导出到air/mindir
|
||||
├── eval.py // evaluation script
|
||||
```
|
||||
|
||||
## [脚本参数](#contents)
|
||||
|
||||
训练以及评估的参数可以在default_config.yaml中设置
|
||||
|
||||
- config for FCN8s
|
||||
|
||||
```default_config.yaml
|
||||
# dataset
|
||||
'data_file': '/data/workspace/mindspore_dataset/FCN/FCN/dataset/MINDRECORED_NAME.mindrecord', # path and name of one mindrecord file
|
||||
'train_batch_size': 32,
|
||||
'crop_size': 512,
|
||||
'image_mean': [103.53, 116.28, 123.675],
|
||||
'image_std': [57.375, 57.120, 58.395],
|
||||
'min_scale': 0.5,
|
||||
'max_scale': 2.0,
|
||||
'ignore_label': 255,
|
||||
'num_classes': 21,
|
||||
|
||||
# optimizer
|
||||
'train_epochs': 500,
|
||||
'base_lr': 0.015,
|
||||
'loss_scale': 1024.0,
|
||||
|
||||
# model
|
||||
'model': 'FCN8s',
|
||||
'ckpt_vgg16': '',
|
||||
'ckpt_pre_trained': '',
|
||||
|
||||
# train
|
||||
'save_steps': 330,
|
||||
'keep_checkpoint_max': 5,
|
||||
'ckpt_dir': './ckpt',
|
||||
```
|
||||
|
||||
如需获取更多信息,Ascend请查看`default_config.yaml`, GPU请查看`gpu_default_config.yaml`.
|
||||
|
||||
## [生成数据步骤](#contents)
|
||||
|
||||
### 训练数据
|
||||
|
||||
- build mindrecord training data
|
||||
|
||||
```python
|
||||
bash build_data.sh
|
||||
or
|
||||
python src/data/build_seg_data.py --data_root=/home/sun/data/Mindspore/benchmark_RELEASE/dataset \
|
||||
--data_lst=/home/sun/data/Mindspore/benchmark_RELEASE/dataset/trainaug.txt \
|
||||
--dst_path=dataset/MINDRECORED_NAME.mindrecord \
|
||||
--num_shards=1 \
|
||||
--shuffle=True
|
||||
data_root: 训练数据集的总目录包含两个子目录img和cls_png,img目录下存放训练图像,cls_png目录下存放标签mask图像,
|
||||
data_lst: 存放训练样本的名称列表文档,每行一个样本。
|
||||
dst_path: 生成mindrecord数据的目标位置
|
||||
```
|
||||
|
||||
## [训练步骤](#contents)
|
||||
|
||||
### 训练
|
||||
|
||||
- running on Ascend with default parameters
|
||||
|
||||
```python
|
||||
# Ascend单卡训练示例
|
||||
python train.py --device_id device_id
|
||||
or
|
||||
bash scripts/run_standalone_train.sh [DEVICE_ID]
|
||||
# example: bash scripts/run_standalone_train.sh 0
|
||||
|
||||
#Ascend八卡并行训练
|
||||
bash scripts/run_train.sh [DEVICE_NUM] rank_table.json
|
||||
# example: bash scripts/run_train.sh 8 /home/hccl_8p_01234567_10.155.170.71.json
|
||||
```
|
||||
|
||||
分布式训练需要提前创建JSON格式的HCCL配置文件,请遵循[链接说明](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)
|
||||
|
||||
- running on GPU with gpu default parameters
|
||||
|
||||
```python
|
||||
# GPU单卡训练示例
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
or
|
||||
bash scripts/run_standalone_train_gpu.sh DEVICE_ID
|
||||
|
||||
# GPU八卡训练示例
|
||||
export RANK_SIZE=8
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
or
|
||||
bash run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
|
||||
|
||||
# GPU评估示例
|
||||
python eval.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
```
|
||||
|
||||
训练时,训练过程中的epch和step以及此时的loss和精确度会呈现log.txt中:
|
||||
|
||||
```python
|
||||
epoch: * step: **, loss is ****
|
||||
...
|
||||
```
|
||||
|
||||
此模型的checkpoint会在默认路径下存储
|
||||
|
||||
- 如果要在modelarts上进行模型的训练,可以参考modelarts的[官方指导文档](https://support.huaweicloud.com/modelarts/) 开始进行模型的训练和推理,具体操作如下:
|
||||
|
||||
```ModelArts
|
||||
# 在ModelArts上使用分布式训练示例:
|
||||
# 数据集存放方式
|
||||
|
||||
# ├── VOC2012 # dir
|
||||
# ├── VOCdevkit # VOCdevkit dir
|
||||
# ├── Please refer to VOCdevkit structure
|
||||
# ├── benchmark_RELEASE # benchmark_RELEASE dir
|
||||
# ├── Please refer to benchmark_RELEASE structure
|
||||
# ├── backbone # backbone dir
|
||||
# ├── vgg_predtrained.ckpt
|
||||
# ├── predtrained # predtrained dir
|
||||
# ├── FCN8s_1-133_300.ckpt
|
||||
# ├── checkpoint # checkpoint dir
|
||||
# ├── FCN8s_1-133_300.ckpt
|
||||
# ├── vocaug_mindrecords # train dataset dir
|
||||
# ├── voctrain.mindrecords0
|
||||
# ├── voctrain.mindrecords0.db
|
||||
# ├── voctrain.mindrecords1
|
||||
# ├── voctrain.mindrecords1.db
|
||||
# ├── voctrain.mindrecords2
|
||||
# ├── voctrain.mindrecords2.db
|
||||
# ├── voctrain.mindrecords3
|
||||
# ├── voctrain.mindrecords3.db
|
||||
# ├── voctrain.mindrecords4
|
||||
# ├── voctrain.mindrecords4.db
|
||||
# ├── voctrain.mindrecords5
|
||||
# ├── voctrain.mindrecords5.db
|
||||
# ├── voctrain.mindrecords6
|
||||
# ├── voctrain.mindrecords6.db
|
||||
# ├── voctrain.mindrecords7
|
||||
# ├── voctrain.mindrecords7.db
|
||||
|
||||
# (1) 选择a(修改yaml文件参数)或者b(ModelArts创建训练作业修改参数)其中一种方式
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "ckpt_dir=/cache/train/outputs_FCN8s/"
|
||||
# 设置 "ckpt_vgg16=/cache/data/backbone/vgg_predtrain file" 如果没有预训练 ckpt_vgg16=""
|
||||
# 设置 "ckpt_pre_trained=/cache/data/predtrained/pred file" 如果无需继续训练 ckpt_pre_trained=""
|
||||
# 设置 "data_file=/cache/data/vocaug_mindrecords/voctrain.mindrecords0"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
|
||||
# (2)设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) 在modelarts的界面上设置代码的路径 "/path/FCN8s"
|
||||
# (4) 在modelarts的界面上设置模型的启动文件 "train.py"
|
||||
# (5) 在modelarts的界面上设置模型的数据路径 ".../VOC2012"(选择VOC2012文件夹路径)
|
||||
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path"
|
||||
# (6) 开始模型的训练
|
||||
|
||||
# 在modelarts上使用模型推理的示例
|
||||
# (1) 把训练好的模型地方到桶的对应位置
|
||||
# (2) 选择a或者b其中一种方式
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "data_root=/cache/data/VOCdevkit/VOC2012/"
|
||||
# 设置 "data_lst=./ImageSets/Segmentation/val.txt"
|
||||
# 设置 "ckpt_file=/cache/data/checkpoint/ckpt file name"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
|
||||
# (3) 设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) 在modelarts的界面上设置代码的路径 "/path/FCN8s"
|
||||
# (5) 在modelarts的界面上设置模型的启动文件 "eval.py"
|
||||
# (6) 在modelarts的界面上设置模型的数据路径 ".../VOC2012"(选择VOC2012文件夹路径) ,
|
||||
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path"
|
||||
# (7) 开始模型的推理
|
||||
```
|
||||
|
||||
## [评估步骤](#contents)
|
||||
|
||||
### 评估
|
||||
|
||||
- 在Ascend或GPU上使用PASCAL VOC 2012 验证集进行评估
|
||||
|
||||
在使用命令运行前,请检查用于评估的checkpoint的路径。请设置路径为到checkpoint的绝对路径,如 "/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt"。
|
||||
|
||||
- eval on Ascend
|
||||
|
||||
```python
|
||||
python eval.py
|
||||
```
|
||||
|
||||
```shell 评估
|
||||
bash scripts/run_eval.sh DATA_ROOT DATA_LST CKPT_PATH
|
||||
# example: bash scripts/run_eval.sh /home/DataSet/voc2012/VOCdevkit/VOC2012 \
|
||||
# /home/DataSet/voc2012/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt /home/FCN8s/ckpt/FCN8s_1-133_300.ckpt
|
||||
```
|
||||
|
||||
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以类似如下方式呈现:
|
||||
|
||||
```python
|
||||
mean IoU 0.6467
|
||||
```
|
||||
|
||||
## 导出过程
|
||||
|
||||
### 导出
|
||||
|
||||
在导出之前需要修改default_config.yaml配置文件中的ckpt_file配置项,file_name和file_format配置项根据情况修改.
|
||||
|
||||
```shell
|
||||
python export.py
|
||||
```
|
||||
|
||||
- 在modelarts上导出MindIR
|
||||
|
||||
```Modelarts
|
||||
在ModelArts上导出MindIR示例
|
||||
数据集存放方式同Modelart训练
|
||||
# (1) 选择a(修改yaml文件参数)或者b(ModelArts创建训练作业修改参数)其中一种方式。
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "file_name=fcn8s"
|
||||
# 设置 "file_format=MINDIR"
|
||||
# 设置 "ckpt_file=/cache/data/checkpoint file name"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
# (2)设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) 在modelarts的界面上设置代码的路径 "/path/fcn8s"。
|
||||
# (4) 在modelarts的界面上设置模型的启动文件 "export.py" 。
|
||||
# (5) 在modelarts的界面上设置模型的数据路径 ".../VOC2012/checkpoint"(选择VOC2012/checkpoint文件夹路径) ,
|
||||
# MindIR的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
```
|
||||
|
||||
## 推理过程
|
||||
|
||||
### 推理
|
||||
|
||||
在还行推理之前我们需要先导出模型。Air模型只能在昇腾910环境上导出,mindir可以在任意环境上导出。batch_size只支持1。
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_LIST_FILE] [IMAGE_PATH] [MASK_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
推理的结果保存在当前目录下,在acc.log日志文件中可以找到类似以下的结果。
|
||||
|
||||
```python
|
||||
mean IoU 0.0.64519877
|
||||
```
|
||||
|
||||
- eval on GPU
|
||||
|
||||
```python
|
||||
python eval.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
```
|
||||
|
||||
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以类似如下方式呈现:
|
||||
|
||||
```python
|
||||
mean IoU 0.6472
|
||||
```
|
||||
|
||||
# [模型介绍](#contents)
|
||||
|
||||
## [性能](#contents)
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### FCN8s on PASCAL VOC 2012
|
||||
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ------------------------------------------------------------| -------------------------------------------------|
|
||||
| Model Version | FCN-8s | FCN-8s |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G |
|
||||
| uploaded Date | 12/30/2020 (month/day/year) | 06/11/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 | 1.2.0 |
|
||||
| Dataset | PASCAL VOC 2012 and SBD | PASCAL VOC 2012 and SBD |
|
||||
| Training Parameters | epoch=500, steps=330, batch_size = 32, lr=0.015 | epoch=500, steps=330, batch_size = 8, lr=0.005 |
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 0.038 | 0.036 |
|
||||
| Speed | 1pc: 564.652 ms/step; | 1pc: 455.460 ms/step; |
|
||||
| Scripts | [FCN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/FCN8s)
|
||||
|
||||
### Inference Performance
|
||||
|
||||
#### FCN8s on PASCAL VOC
|
||||
|
||||
| Parameters | Ascend | GPU
|
||||
| ------------------- | --------------------------- | ---------------------------
|
||||
| Model Version | FCN-8s | FCN-8s
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G
|
||||
| Uploaded Date | 10/29/2020 (month/day/year) | 06/11/2021 (month/day/year)
|
||||
| MindSpore Version | 1.1.0 | 1.2.0
|
||||
| Dataset | PASCAL VOC 2012 | PASCAL VOC 2012
|
||||
| batch_size | 16 | 16
|
||||
| outputs | probability | probability
|
||||
| mean IoU | 64.67 | 64.72
|
||||
|
||||
## [如何使用](#contents)
|
||||
|
||||
### 教程
|
||||
|
||||
如果你需要在不同硬件平台(如GPU,Ascend 910 或者 Ascend 310)使用训练好的模型,你可以参考这个 [Link](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/multi_platform_inference.html)。以下是一个简单例子的步骤介绍:
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
```
|
||||
# Set context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
|
||||
context.set_auto_parallel_context(device_num=device_num,parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
init()
|
||||
|
||||
# Load dataset
|
||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||
image_std=cfg.image_std,
|
||||
data_file=cfg.data_file,
|
||||
batch_size=cfg.batch_size,
|
||||
crop_size=cfg.crop_size,
|
||||
max_scale=cfg.max_scale,
|
||||
min_scale=cfg.min_scale,
|
||||
ignore_label=cfg.ignore_label,
|
||||
num_classes=cfg.num_classes,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=args.rank,
|
||||
shard_num=args.group_size)
|
||||
dataset = dataset.get_dataset(repeat=1)
|
||||
|
||||
# Define model
|
||||
net = FCN8s(n_class=cfg.num_classes)
|
||||
loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
total_train_steps = iters_per_epoch * cfg.train_epochs
|
||||
|
||||
lr_scheduler = CosineAnnealingLR(cfg.base_lr,
|
||||
cfg.train_epochs,
|
||||
iters_per_epoch,
|
||||
cfg.train_epochs,
|
||||
warmup_epochs=0,
|
||||
eta_min=0)
|
||||
lr = Tensor(lr_scheduler.get_lr())
|
||||
|
||||
# loss scale
|
||||
manager_loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=cfg.loss_scale)
|
||||
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
# callback for saving ckpts
|
||||
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
||||
loss_cb = LossMonitor()
|
||||
cbs = [time_cb, loss_cb]
|
||||
|
||||
if args.rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.ckpt_dir, config=config_ck)
|
||||
cbs.append(ckpoint_cb)
|
||||
|
||||
model.train(cfg.train_epochs, dataset, callbacks=cbs)
|
||||
|
||||
# [随机事件介绍](#contents)
|
||||
|
||||
我们在train.py中设置了随机种子
|
||||
|
||||
# [ModelZoo 主页](#contents)
|
||||
|
||||
请查看官方网站 [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
option(MINDSPORE_PATH "mindspore install path" "")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -1,29 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ -d out ]; then
|
||||
rm -rf out
|
||||
fi
|
||||
|
||||
mkdir out
|
||||
cd out || exit
|
||||
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
|
@ -1,33 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
std::vector<std::string> GetImagesById(const std::string &idFIle, const std::string &dirName);
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
std::string RealPath(std::string_view path);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||
#endif
|
|
@ -1,223 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "../inc/utils.h"
|
||||
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Status;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::dataset::TensorTransform;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
using mindspore::dataset::vision::Pad;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::SwapRedBlue;
|
||||
using mindspore::dataset::vision::Decode;
|
||||
|
||||
|
||||
DEFINE_string(mindir_path, "", "mindir path");
|
||||
DEFINE_string(image_list, "", "image list");
|
||||
DEFINE_string(dataset_path, ".", "dataset path");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
|
||||
const int IMAGEWIDTH = 512;
|
||||
const int IMAGEHEIGHT = 512;
|
||||
|
||||
int PadImage(const MSTensor &input, MSTensor *output) {
|
||||
std::shared_ptr<TensorTransform> normalize(new Normalize({103.53, 116.28, 123.675},
|
||||
{57.375, 57.120, 58.395}));
|
||||
Execute composeNormalize({normalize});
|
||||
std::vector<int64_t> shape = input.Shape();
|
||||
auto imgResize = MSTensor();
|
||||
auto imgNormalize = MSTensor();
|
||||
|
||||
float widthScale, heightScale;
|
||||
widthScale = static_cast<float>(IMAGEWIDTH) / shape[1];
|
||||
heightScale = static_cast<float>(IMAGEHEIGHT) / shape[0];
|
||||
Status ret;
|
||||
if (widthScale < heightScale) {
|
||||
int heightSize = shape[0]*widthScale;
|
||||
std::shared_ptr<TensorTransform> resize(new Resize({heightSize, IMAGEWIDTH}));
|
||||
Execute composeResizeWidth({resize});
|
||||
ret = composeResizeWidth(input, &imgResize);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Resize Width failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ret = composeNormalize(imgResize, &imgNormalize);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Normalize failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int paddingSize = IMAGEHEIGHT - heightSize;
|
||||
std::shared_ptr<TensorTransform> pad(new Pad({0, 0, 0, paddingSize}));
|
||||
Execute composePad({pad});
|
||||
ret = composePad(imgNormalize, output);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Height Pad failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
int widthSize = shape[1]*heightScale;
|
||||
std::shared_ptr<TensorTransform> resize(new Resize({IMAGEHEIGHT, widthSize}));
|
||||
Execute composeResizeHeight({resize});
|
||||
ret = composeResizeHeight(input, &imgResize);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Resize Height failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ret = composeNormalize(imgResize, &imgNormalize);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Normalize failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int paddingSize = IMAGEWIDTH - widthSize;
|
||||
std::shared_ptr<TensorTransform> pad(new Pad({0, 0, paddingSize, 0}));
|
||||
Execute composePad({pad});
|
||||
ret = composePad(imgNormalize, output);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Width Pad failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
if (RealPath(FLAGS_mindir_path).empty()) {
|
||||
std::cout << "Invalid mindir" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(FLAGS_device_id);
|
||||
ascend310->SetPrecisionMode("allow_fp32_to_fp16");
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
mindspore::Graph graph;
|
||||
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||
|
||||
Model model;
|
||||
Status ret = model.Build(GraphCell(graph), context);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Build failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
std::vector<MSTensor> model_inputs = model.GetInputs();
|
||||
if (model_inputs.empty()) {
|
||||
std::cout << "Invalid model, inputs is empty." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto all_files = GetImagesById(FLAGS_image_list, FLAGS_dataset_path);
|
||||
if (all_files.empty()) {
|
||||
std::cout << "ERROR: no input data." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::map<double, double> costTime_map;
|
||||
size_t size = all_files.size();
|
||||
std::shared_ptr<TensorTransform> decode(new Decode());
|
||||
Execute composeDecode({decode});
|
||||
std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
|
||||
Execute composeTranspose({hwc2chw});
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
double startTimeMs;
|
||||
double endTimeMs;
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
std::string file = all_files[i] + ".jpg";
|
||||
std::cout << "Start predict input files:" << file << std::endl;
|
||||
auto imgDecode = MSTensor();
|
||||
|
||||
auto image = ReadFileToTensor(file);
|
||||
ret = composeDecode(image, &imgDecode);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Decode failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
auto imgPad = MSTensor();
|
||||
PadImage(imgDecode, &imgPad);
|
||||
auto img = MSTensor();
|
||||
composeTranspose(imgPad, &img);
|
||||
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
img.Data().get(), img.DataSize());
|
||||
|
||||
gettimeofday(&start, nullptr);
|
||||
ret = model.Predict(inputs, &outputs);
|
||||
gettimeofday(&end, nullptr);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Predict " << file << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||
WriteResult(file, outputs);
|
||||
}
|
||||
double average = 0.0;
|
||||
int inferCount = 0;
|
||||
|
||||
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||
double diff = 0.0;
|
||||
diff = iter->second - iter->first;
|
||||
average += diff;
|
||||
inferCount++;
|
||||
}
|
||||
average = average / inferCount;
|
||||
std::stringstream timeCost;
|
||||
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
|
||||
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
|
||||
|
||||
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
|
||||
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||
fileStream << timeCost.str();
|
||||
fileStream.close();
|
||||
costTime_map.clear();
|
||||
return 0;
|
||||
}
|
|
@ -1,145 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "../inc/utils.h"
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dirName);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> res;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||
continue;
|
||||
}
|
||||
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
}
|
||||
std::sort(res.begin(), res.end());
|
||||
for (auto &f : res) {
|
||||
std::cout << "image file: " << f << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetImagesById(const std::string &idFile, const std::string &dirName) {
|
||||
std::ifstream readFile(idFile);
|
||||
std::string id;
|
||||
std::vector<std::string> result;
|
||||
|
||||
if (!readFile.is_open()) {
|
||||
std::cout << "can not open image id txt file" << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
while (getline(readFile, id)) {
|
||||
result.emplace_back(dirName + "/" + id);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||
std::string homePath = "./result_Files";
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
size_t outputSize;
|
||||
std::shared_ptr<const void> netOutput;
|
||||
netOutput = outputs[i].Data();
|
||||
outputSize = outputs[i].DataSize();
|
||||
int pos = imageFile.rfind('/');
|
||||
std::string fileName(imageFile, pos + 1);
|
||||
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||
std::string outFileName = homePath + "/" + fileName;
|
||||
FILE * outputFile = fopen(outFileName.c_str(), "wb");
|
||||
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
MSTensor ReadFileToTensor(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr" << std::endl;
|
||||
return MSTensor();
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||
return MSTensor();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << "open failed" << std::endl;
|
||||
return MSTensor();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
|
||||
if (dirName.empty()) {
|
||||
std::cout << " dirName is null ! " << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = RealPath(dirName);
|
||||
struct stat s;
|
||||
lstat(realPath.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
DIR *dir;
|
||||
dir = opendir(realPath.c_str());
|
||||
if (dir == nullptr) {
|
||||
std::cout << "Can not open dir " << dirName << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||
return dir;
|
||||
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
|
||||
char realPathMem[PATH_MAX] = {0};
|
||||
char *realPathRet = nullptr;
|
||||
realPathRet = realpath(path.data(), realPathMem);
|
||||
|
||||
if (realPathRet == nullptr) {
|
||||
std::cout << "File: " << path << " is not exist.";
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string realPath(realPathMem);
|
||||
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||
return realPath;
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
# ======================================================================================
|
||||
# common options
|
||||
|
||||
crop_size: 512
|
||||
image_mean: [103.53, 116.28, 123.675]
|
||||
image_std: [57.375, 57.120, 58.395]
|
||||
ignore_label: 255
|
||||
num_classes: 21
|
||||
model: "FCN8s"
|
||||
parallel_mode: "data_parallel"
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
train_batch_size: 32
|
||||
min_scale: 0.5
|
||||
max_scale: 2.0
|
||||
data_file: "/data/mjq/dataset/vocaug_local_mindrecords/vocaug_local_mindrecords.mindrecords"
|
||||
|
||||
# optimizer
|
||||
train_epochs: 500
|
||||
base_lr: 0.015
|
||||
loss_scale: 1024
|
||||
|
||||
# model
|
||||
ckpt_vgg16: "/data/mjq/ckpt/vgg16_predtrain.ckpt"
|
||||
ckpt_pre_trained: ""
|
||||
save_steps: 330
|
||||
keep_checkpoint_max: 5
|
||||
ckpt_dir: "./ckpt"
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
eval_batch_size: 16
|
||||
data_root: "/data/mjq/dataset/VOCdevkit/VOC2012"
|
||||
data_lst: "/data/mjq/dataset/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt"
|
||||
scales: [1.0]
|
||||
flip: False
|
||||
freeze_bn: False
|
||||
ckpt_file: "/data/mjq/ckpt/FCN8s_1-133_300.ckpt"
|
||||
|
||||
# ======================================================================================
|
||||
# Export options
|
||||
file_name: "fcn8s"
|
||||
file_format: MINDIR
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
crop_size: "crop_size"
|
||||
image_mean: "image_mean"
|
||||
image_std: "image std"
|
||||
ignore_label: "ignore label"
|
||||
num_classes: "number of classes"
|
||||
model: "select model"
|
||||
data_file: "path of train data"
|
||||
train_batch_size: "train_batch_size"
|
||||
min_scale: "min scales of train"
|
||||
max_scale: "max scales of train"
|
||||
train_epochs: "train epoch"
|
||||
base_lr: "base lr"
|
||||
loss_scale: "loss scales"
|
||||
ckpt_vgg16: "backbone pretrain"
|
||||
ckpt_pre_trained: "model pretrain"
|
||||
data_root: "root path of val data"
|
||||
eval_batch_size: "eval batch size"
|
||||
data_lst: "list of val data"
|
||||
scales: "scales of evaluation"
|
||||
flip: "freeze bn"
|
||||
ckpt_file: "model to evaluate"
|
||||
file_name: "export file name"
|
||||
file_format: "export model type"
|
|
@ -1,187 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval FCN8s."""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.nets.FCN8s import FCN8s
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
|
||||
def cal_hist(a, b, n):
|
||||
k = (a >= 0) & (a < n)
|
||||
return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
|
||||
|
||||
|
||||
def resize_long(img, long_size=513):
|
||||
h, w, _ = img.shape
|
||||
if h > w:
|
||||
new_h = long_size
|
||||
new_w = int(1.0 * long_size * w / h)
|
||||
else:
|
||||
new_w = long_size
|
||||
new_h = int(1.0 * long_size * h / w)
|
||||
imo = cv2.resize(img, (new_w, new_h))
|
||||
return imo
|
||||
|
||||
|
||||
class BuildEvalNetwork(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(BuildEvalNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.softmax = nn.Softmax(axis=1)
|
||||
|
||||
def construct(self, input_data):
|
||||
output = self.network(input_data)
|
||||
output = self.softmax(output)
|
||||
return output
|
||||
|
||||
|
||||
def pre_process(configs, img_, crop_size=512):
|
||||
# resize
|
||||
img_ = resize_long(img_, crop_size)
|
||||
resize_h, resize_w, _ = img_.shape
|
||||
|
||||
# mean, std
|
||||
image_mean = np.array(configs.image_mean)
|
||||
image_std = np.array(configs.image_std)
|
||||
img_ = (img_ - image_mean) / image_std
|
||||
|
||||
# pad to crop_size
|
||||
pad_h = crop_size - img_.shape[0]
|
||||
pad_w = crop_size - img_.shape[1]
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
|
||||
# hwc to chw
|
||||
img_ = img_.transpose((2, 0, 1))
|
||||
return img_, resize_h, resize_w
|
||||
|
||||
|
||||
def eval_batch(configs, eval_net, img_lst, crop_size=512, flip=True):
|
||||
result_lst = []
|
||||
batch_size = len(img_lst)
|
||||
batch_img = np.zeros((configs.eval_batch_size, 3, crop_size, crop_size), dtype=np.float32)
|
||||
resize_hw = []
|
||||
for l in range(batch_size):
|
||||
img_ = img_lst[l]
|
||||
img_, resize_h, resize_w = pre_process(configs, img_, crop_size)
|
||||
batch_img[l] = img_
|
||||
resize_hw.append([resize_h, resize_w])
|
||||
|
||||
batch_img = np.ascontiguousarray(batch_img)
|
||||
net_out = eval_net(Tensor(batch_img, mstype.float32))
|
||||
net_out = net_out.asnumpy()
|
||||
|
||||
if flip:
|
||||
batch_img = batch_img[:, :, :, ::-1]
|
||||
net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
|
||||
net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
|
||||
|
||||
for bs in range(batch_size):
|
||||
probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
|
||||
ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
|
||||
probs_ = cv2.resize(probs_.astype(np.float32), (ori_w, ori_h))
|
||||
result_lst.append(probs_)
|
||||
|
||||
return result_lst
|
||||
|
||||
|
||||
def eval_batch_scales(configs, eval_net, img_lst, scales,
|
||||
base_crop_size=512, flip=True):
|
||||
sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
|
||||
probs_lst = eval_batch(configs, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
|
||||
print(sizes_)
|
||||
for crop_size_ in sizes_[1:]:
|
||||
probs_lst_tmp = eval_batch(configs, eval_net, img_lst, crop_size=crop_size_, flip=flip)
|
||||
for pl, _ in enumerate(probs_lst):
|
||||
probs_lst[pl] += probs_lst_tmp[pl]
|
||||
|
||||
result_msk = []
|
||||
for i in probs_lst:
|
||||
result_msk.append(i.argmax(axis=2))
|
||||
return result_msk
|
||||
|
||||
|
||||
@moxing_wrapper()
|
||||
def net_eval():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id(),
|
||||
save_graphs=False)
|
||||
|
||||
# data list
|
||||
with open(config.data_lst) as f:
|
||||
img_lst = f.readlines()
|
||||
|
||||
net = FCN8s(n_class=config.num_classes)
|
||||
|
||||
# load model
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# evaluate
|
||||
hist = np.zeros((config.num_classes, config.num_classes))
|
||||
batch_img_lst = []
|
||||
batch_msk_lst = []
|
||||
bi = 0
|
||||
image_num = 0
|
||||
for i, line in enumerate(img_lst):
|
||||
|
||||
img_name = line.strip('\n')
|
||||
data_root = config.data_root
|
||||
img_path = data_root + '/JPEGImages/' + str(img_name) + '.jpg'
|
||||
msk_path = data_root + '/SegmentationClass/' + str(img_name) + '.png'
|
||||
|
||||
img_ = np.array(Image.open(img_path), dtype=np.uint8)
|
||||
msk_ = np.array(Image.open(msk_path), dtype=np.uint8)
|
||||
|
||||
batch_img_lst.append(img_)
|
||||
batch_msk_lst.append(msk_)
|
||||
bi += 1
|
||||
if bi == config.eval_batch_size:
|
||||
batch_res = eval_batch_scales(config, net, batch_img_lst, scales=config.scales,
|
||||
base_crop_size=config.crop_size, flip=config.flip)
|
||||
for mi in range(config.eval_batch_size):
|
||||
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), config.num_classes)
|
||||
|
||||
bi = 0
|
||||
batch_img_lst = []
|
||||
batch_msk_lst = []
|
||||
print('processed {} images'.format(i+1))
|
||||
image_num = i
|
||||
|
||||
if bi > 0:
|
||||
batch_res = eval_batch_scales(config, net, batch_img_lst, scales=config.scales,
|
||||
base_crop_size=config.crop_size, flip=config.flip)
|
||||
for mi in range(bi):
|
||||
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), config.num_classes)
|
||||
print('processed {} images'.format(image_num + 1))
|
||||
|
||||
print(hist)
|
||||
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
|
||||
print('per-class IoU', iu)
|
||||
print('mean IoU', np.nanmean(iu))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net_eval()
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export FCN8s."""
|
||||
import os
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from src.nets.FCN8s import FCN8s
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
config.file_name = os.path.join(config.output_path, config.file_name)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_export():
|
||||
net = FCN8s(n_class=config.num_classes)
|
||||
|
||||
# load model
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([1, 3, config.crop_size, config.crop_size]), ms.float32)
|
||||
export(net, input_arr, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_export()
|
|
@ -1,86 +0,0 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "GPU"
|
||||
enable_profiling: False
|
||||
checkpoint_path: "./checkpoint/"
|
||||
checkpoint_file: "./checkpoint/.ckpt"
|
||||
# ======================================================================================
|
||||
# common options
|
||||
|
||||
crop_size: 512
|
||||
image_mean: [103.53, 116.28, 123.675]
|
||||
image_std: [57.375, 57.120, 58.395]
|
||||
ignore_label: 255
|
||||
num_classes: 21
|
||||
model: "FCN8s"
|
||||
parallel_mode: "data_parallel"
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
train_batch_size: 8
|
||||
min_scale: 0.5
|
||||
max_scale: 2.0
|
||||
data_file: "./vocaug_local_mindrecords/vocaug_local_mindrecords.mindrecords" # change to your own path of train data
|
||||
|
||||
# optimizer
|
||||
train_epochs: 500
|
||||
base_lr: 0.005
|
||||
loss_scale: 1024
|
||||
|
||||
# model
|
||||
ckpt_vgg16: "./vgg16_predtrain.ckpt" # change to your own path of backbone pretrain
|
||||
ckpt_pre_trained: ""
|
||||
save_steps: 330
|
||||
keep_checkpoint_max: 5
|
||||
ckpt_dir: "./ckpt"
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
eval_batch_size: 16
|
||||
data_root: "./VOCdevkit/VOC2012" # change to your own path of val data
|
||||
data_lst: "./VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt" # change to your own path of val data list
|
||||
scales: [1.0]
|
||||
flip: False
|
||||
freeze_bn: False
|
||||
ckpt_file: "./FCN8s_1-500_220.ckpt" # change to your own path of evaluate model
|
||||
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
crop_size: "crop_size"
|
||||
image_mean: "image_mean"
|
||||
image_std: "image std"
|
||||
ignore_label: "ignore label"
|
||||
num_classes: "number of classes"
|
||||
model: "select model"
|
||||
data_file: "path of train data"
|
||||
train_batch_size: "train_batch_size"
|
||||
min_scale: "min scales of train"
|
||||
max_scale: "max scales of train"
|
||||
train_epochs: "train epoch"
|
||||
base_lr: "base lr"
|
||||
loss_scale: "loss scales"
|
||||
ckpt_vgg16: "backbone pretrain"
|
||||
ckpt_pre_trained: "model pretrain"
|
||||
data_root: "root path of val data"
|
||||
eval_batch_size: "eval batch size"
|
||||
data_lst: "list of val data"
|
||||
scales: "scales of evaluation"
|
||||
flip: "freeze bn"
|
||||
ckpt_file: "model to evaluate"
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.nets.FCN8s import FCN8s
|
||||
|
||||
def fcn8s_net(*args, **kwargs):
|
||||
return FCN8s(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about FCN8s"""
|
||||
if name == "fcn8s":
|
||||
num_classes = 21
|
||||
return fcn8s_net(n_class=num_classes, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""post process for 310 inference"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
parser = argparse.ArgumentParser(description="FasterRcnn inference")
|
||||
parser.add_argument("--image_list", type=str, required=True, help="result file path.")
|
||||
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
|
||||
parser.add_argument("--data_path", type=str, required=True, help="mask file path.")
|
||||
parser.add_argument("--mask_path", type=str, required=True, help="mask file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
NUM_CLASSES = 21
|
||||
|
||||
def get_img_size(file_name):
|
||||
img = Image.open(file_name)
|
||||
return img.size
|
||||
|
||||
def get_resized_size(org_h, org_w, long_size=512):
|
||||
if org_h > org_w:
|
||||
new_h = long_size
|
||||
new_w = int(1.0 * long_size * org_w / org_h)
|
||||
else:
|
||||
new_w = long_size
|
||||
new_h = int(1.0 * long_size * org_h / org_w)
|
||||
|
||||
return new_h, new_w
|
||||
|
||||
def cal_hist(a, b, n):
|
||||
k = (a >= 0) & (a < n)
|
||||
return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
|
||||
|
||||
def cal_acc(image_list, data_path, result_path, mask_path):
|
||||
hist = np.zeros((NUM_CLASSES, NUM_CLASSES))
|
||||
with open(image_list) as f:
|
||||
img_list = f.readlines()
|
||||
|
||||
for img in img_list:
|
||||
img_file = os.path.join(data_path, img.strip() + ".jpg")
|
||||
org_width, org_height = get_img_size(img_file)
|
||||
|
||||
resize_h, resize_w = get_resized_size(org_height, org_width)
|
||||
|
||||
result_file = os.path.join(result_path, img.strip() + "_0.bin")
|
||||
result = np.fromfile(result_file, dtype=np.float32).reshape(21, 512, 512)
|
||||
probs_ = result[:, :resize_h, :resize_w].transpose((1, 2, 0))
|
||||
probs_ = cv2.resize(probs_.astype(np.float32), (org_width, org_height))
|
||||
result_msk = probs_.argmax(axis=2)
|
||||
|
||||
mask_file = os.path.join(mask_path, img.strip() + ".png")
|
||||
mask = np.array(Image.open(mask_file), dtype=np.uint8)
|
||||
|
||||
hist += cal_hist(mask.flatten(), result_msk.flatten(), NUM_CLASSES)
|
||||
|
||||
#print(hist)
|
||||
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
|
||||
print('per-class IoU', iu)
|
||||
print('mean IoU', np.nanmean(iu))
|
||||
|
||||
if __name__ == '__main__':
|
||||
cal_acc(args.image_list, args.data_path, args.result_path, args.mask_path)
|
|
@ -1,4 +0,0 @@
|
|||
numpy
|
||||
opencv-python
|
||||
pyyaml
|
||||
pillow
|
|
@ -1,22 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=0
|
||||
python src/data/build_seg_data.py --data_root=/data/mjq/dataset \
|
||||
--data_lst=/data/mjq/dataset/vocaug_train_lst.txt \
|
||||
--dst_path=./mindrecords/vocaug_train.mindrecords \
|
||||
--num_shards=1 \
|
||||
--shuffle=True
|
|
@ -1,55 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
export RANK_SIZE=$1
|
||||
PROJECT_DIR=$(cd ./"`dirname $0`" || exit; pwd)
|
||||
TRAIN_DATA_DIR=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $TRAIN_DATA_DIR ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$TRAIN_DATA_DIR is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "distribute_train" ]; then
|
||||
rm -rf ./distribute_train
|
||||
fi
|
||||
|
||||
mkdir ./distribute_train
|
||||
cp ./*.py ./distribute_train
|
||||
cp ./*.yaml ./distribute_train
|
||||
cp -r ./src ./distribute_train
|
||||
cd ./distribute_train || exit
|
||||
|
||||
CONFIG_FILE="$PROJECT_DIR/../gpu_default_config.yaml"
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
nohup python train.py \
|
||||
--config_path=$CONFIG_FILE \
|
||||
--device_target=GPU > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -1,44 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_distribute_eval.sh DEVICE_NUM RANK_TABLE_FILE DATASET CKPT_PATH"
|
||||
echo "for example: sh scripts/run_eval.sh path/to/data_root /path/to/dataset /path/to/ckpt"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
|
||||
export DATA_ROOT=$1
|
||||
DATA_PATH=$2
|
||||
CKPT_PATH=$3
|
||||
|
||||
rm -rf eval
|
||||
mkdir ./eval
|
||||
cp ./*.py ./eval
|
||||
cp ./*.yaml ./eval
|
||||
cp -r ./src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start testing"
|
||||
env > env.log
|
||||
python eval.py \
|
||||
--data_root=$DATA_ROOT \
|
||||
--data_lst=$DATA_PATH \
|
||||
--ckpt_file=$CKPT_PATH #> log.txt 2>&1 &
|
||||
|
||||
cd ../
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 4 || $# -gt 5 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_LIST_FILE] [IMAGE_PATH] [MASK_PATH] [DEVICE_ID]
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
model=$(get_real_path $1)
|
||||
data_list_file=$(get_real_path $2)
|
||||
image_path=$(get_real_path $3)
|
||||
mask_path=$(get_real_path $4)
|
||||
|
||||
device_id=0
|
||||
if [ $# == 5 ]; then
|
||||
device_id=$5
|
||||
elif [ $# == 4 ]; then
|
||||
if [ ! -z $device_id ]; then
|
||||
device_id=$device_id
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $model
|
||||
echo $image_path
|
||||
echo $mask_path
|
||||
echo $device_id
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
cd ../ascend310_infer || exit
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
sh build.sh &> build.log
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed"
|
||||
exit 1
|
||||
fi
|
||||
cd - || exit
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
if [ -d result_Files ]; then
|
||||
rm -rf ./result_Files
|
||||
fi
|
||||
if [ -d time_Result ]; then
|
||||
rm -rf ./time_Result
|
||||
fi
|
||||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
../ascend310_infer/out/main --image_list=$data_list_file --mindir_path=$model --dataset_path=$image_path --device_id=$device_id &> infer.log
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "execute inference failed"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
function cal_acc()
|
||||
{
|
||||
python ../postprocess.py --image_list=$data_list_file --data_path=$image_path --mask_path=$mask_path --result_path=result_Files &> acc.log
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "calculate accuracy failed"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
compile_app
|
||||
infer
|
||||
cal_acc
|
|
@ -1,39 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_standalone_train.sh DEVICE_ID"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
export DEVICE_ID=$1
|
||||
train_path=train_standalone${DEVICE_ID}
|
||||
|
||||
if [ -d ${train_path} ]; then
|
||||
rm -rf ${train_path}
|
||||
fi
|
||||
mkdir -p ${train_path}
|
||||
cp -r ./src ${train_path}
|
||||
cp ./train.py ${train_path}
|
||||
cp ./*.yaml ${train_path}
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
cd ${train_path}|| exit
|
||||
python train.py > log 2>&1 &
|
||||
cd ..
|
|
@ -1,44 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_standalone_train_gpu.sh DEVICE_ID"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_ID=$1
|
||||
PROJECT_DIR=$(cd ./"`dirname $0`" || exit; pwd)
|
||||
train_path=train_standalone${DEVICE_ID}
|
||||
|
||||
if [ -d ${train_path} ]; then
|
||||
rm -rf ${train_path}
|
||||
fi
|
||||
mkdir -p ${train_path}
|
||||
cp -r ./src ${train_path}
|
||||
cp ./train.py ${train_path}
|
||||
cp ./*.yaml ${train_path}
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
cd ${train_path}|| exit
|
||||
|
||||
CONFIG_FILE="$PROJECT_DIR/../gpu_default_config.yaml"
|
||||
|
||||
nohup python train.py \
|
||||
--config_path=$CONFIG_FILE \
|
||||
--device_target=GPU > log 2>&1 &
|
||||
cd ..
|
|
@ -1,53 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_train.sh [device_num][RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $2 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=$1
|
||||
export RANK_SIZE=$1
|
||||
RANK_TABLE_FILE=$(realpath $2)
|
||||
export RANK_TABLE_FILE
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<$1; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
cp ./*.yaml ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
python train.py > log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -1,76 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindrecord')
|
||||
|
||||
parser.add_argument('--data_root', type=str, default='../../VOC2012', help='root path of data')
|
||||
parser.add_argument('--data_lst', type=str, default='ImageSets/Segmentation/trainval.txt', help='list of data')
|
||||
parser.add_argument('--dst_path', type=str, default='./mindname.mindrecord', help='save path of mindrecords')
|
||||
parser.add_argument('--num_shards', type=int, default=1, help='number of shards')
|
||||
parser.add_argument('--shuffle', type=bool, default=True, help='shuffle or not')
|
||||
|
||||
parser_args, _ = parser.parse_known_args()
|
||||
return parser_args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
data_list = []
|
||||
with open(args.data_lst) as f:
|
||||
lines = f.readlines()
|
||||
if args.shuffle:
|
||||
np.random.shuffle(lines)
|
||||
|
||||
dst_dir = '/'.join(args.dst_path.split('/')[:-1])
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
|
||||
print('number of samples:', len(lines))
|
||||
writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
|
||||
writer.add_schema(seg_schema, "seg_schema")
|
||||
cnt = 0
|
||||
|
||||
for l in lines:
|
||||
img_path = l.split(' ')[0].strip('\n')
|
||||
label_path = l.split(' ')[1].strip('\n')
|
||||
|
||||
sample_ = {"file_name": img_path.split('/')[-1]}
|
||||
|
||||
with open(os.path.join(args.data_root, img_path), 'rb') as f:
|
||||
sample_['data'] = f.read()
|
||||
with open(os.path.join(args.data_root, label_path), 'rb') as f:
|
||||
sample_['label'] = f.read()
|
||||
data_list.append(sample_)
|
||||
cnt += 1
|
||||
if cnt % 1000 == 0:
|
||||
writer.write_raw_data(data_list)
|
||||
print('number of samples written:', cnt)
|
||||
data_list = []
|
||||
|
||||
if data_list:
|
||||
writer.write_raw_data(data_list)
|
||||
writer.commit()
|
||||
print('number of samples written:', cnt)
|
|
@ -1,94 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import mindspore.dataset as de
|
||||
cv2.setNumThreads(0)
|
||||
|
||||
|
||||
class SegDataset:
|
||||
def __init__(self,
|
||||
image_mean,
|
||||
image_std,
|
||||
data_file='',
|
||||
batch_size=32,
|
||||
crop_size=512,
|
||||
max_scale=2.0,
|
||||
min_scale=0.5,
|
||||
ignore_label=255,
|
||||
num_classes=21,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=None,
|
||||
shard_num=None):
|
||||
|
||||
self.data_file = data_file
|
||||
self.batch_size = batch_size
|
||||
self.crop_size = crop_size
|
||||
self.image_mean = np.array(image_mean, dtype=np.float32)
|
||||
self.image_std = np.array(image_std, dtype=np.float32)
|
||||
self.max_scale = max_scale
|
||||
self.min_scale = min_scale
|
||||
self.ignore_label = ignore_label
|
||||
self.num_classes = num_classes
|
||||
self.num_readers = num_readers
|
||||
self.num_parallel_calls = num_parallel_calls
|
||||
self.shard_id = shard_id
|
||||
self.shard_num = shard_num
|
||||
assert max_scale > min_scale
|
||||
|
||||
def preprocess_(self, image, label):
|
||||
# bgr image
|
||||
image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
sc = np.random.uniform(self.min_scale, self.max_scale)
|
||||
new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
|
||||
image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
|
||||
label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
image_out = (image_out - self.image_mean) / self.image_std
|
||||
h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size)
|
||||
pad_h, pad_w = h_ - new_h, w_ - new_w
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
|
||||
offset_h = np.random.randint(0, h_ - self.crop_size + 1)
|
||||
offset_w = np.random.randint(0, w_ - self.crop_size + 1)
|
||||
image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
|
||||
label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
|
||||
|
||||
if np.random.uniform(0.0, 1.0) > 0.5:
|
||||
image_out = image_out[:, ::-1, :]
|
||||
label_out = label_out[:, ::-1]
|
||||
|
||||
image_out = image_out.transpose((2, 0, 1))
|
||||
image_out = image_out.copy()
|
||||
label_out = label_out.copy()
|
||||
return image_out, label_out
|
||||
|
||||
def get_dataset(self, repeat=1):
|
||||
data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"],
|
||||
shuffle=True, num_parallel_workers=self.num_readers,
|
||||
num_shards=self.shard_num, shard_id=self.shard_id)
|
||||
transforms_list = self.preprocess_
|
||||
data_set = data_set.map(operations=transforms_list, input_columns=["data", "label"],
|
||||
output_columns=["data", "label"],
|
||||
num_parallel_workers=self.num_parallel_calls)
|
||||
data_set = data_set.shuffle(buffer_size=self.batch_size * 10)
|
||||
data_set = data_set.batch(self.batch_size, drop_remainder=True)
|
||||
data_set = data_set.repeat(repeat)
|
||||
return data_set
|
|
@ -1,52 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLoss(nn.Cell):
|
||||
def __init__(self, num_cls=21, ignore_label=255, device_num=8):
|
||||
super(SoftmaxCrossEntropyLoss, self).__init__()
|
||||
self.one_hot = P.OneHot(axis=-1)
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.num_cls = num_cls
|
||||
self.ignore_label = ignore_label
|
||||
self.mul = P.Mul()
|
||||
self.sum = P.ReduceSum(False)
|
||||
self.div = P.RealDiv()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose.shard(((1, 1, 1, device_num),))
|
||||
|
||||
def construct(self, logits, labels):
|
||||
labels_int = self.cast(labels, mstype.int32)
|
||||
labels_int = self.reshape(labels_int, (-1,))
|
||||
logits_ = self.transpose(logits, (0, 2, 3, 1))
|
||||
logits_ = self.reshape(logits_, (-1, self.num_cls))
|
||||
weights = self.not_equal(labels_int, self.ignore_label)
|
||||
weights = self.cast(weights, mstype.float32)
|
||||
one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
|
||||
logits_ = self.cast(logits_, mstype.float32)
|
||||
loss = self.ce(logits_, one_hot_labels)
|
||||
loss = self.mul(weights, loss)
|
||||
loss = self.div(self.sum(loss), self.sum(weights))
|
||||
return loss
|
|
@ -1,130 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
|
||||
global_yaml = '../../default_config.yaml'
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, global_yaml),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -1,36 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -1,124 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -1,212 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class FCN8s(nn.Cell):
|
||||
def __init__(self, n_class):
|
||||
super().__init__()
|
||||
self.n_class = n_class
|
||||
self.conv1 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=3, out_channels=64,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=64, out_channels=64,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv2 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=64, out_channels=128,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=128, out_channels=128,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv3 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=128, out_channels=256,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256, out_channels=256,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256, out_channels=256,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv4 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=256, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv5 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv6 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512, out_channels=4096,
|
||||
kernel_size=7, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.conv7 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=4096, out_channels=4096,
|
||||
kernel_size=1, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
|
||||
kernel_size=1, weight_init='xavier_uniform')
|
||||
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
|
||||
kernel_size=4, stride=2, weight_init='xavier_uniform')
|
||||
self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
|
||||
kernel_size=1, weight_init='xavier_uniform')
|
||||
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
|
||||
kernel_size=4, stride=2, weight_init='xavier_uniform')
|
||||
self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
|
||||
kernel_size=1, weight_init='xavier_uniform')
|
||||
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
|
||||
kernel_size=16, stride=8, weight_init='xavier_uniform')
|
||||
self.shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
self.add1 = P.Add()
|
||||
self.add2 = P.Add()
|
||||
|
||||
def set_model_parallel_shard_strategy(self, device_num):
|
||||
self.conv2d_strategy = ((1, 1, 1, device_num), (1, 1, 1, 1))
|
||||
self.bn_strategy = ((1, 1, 1, device_num), (1,), (1,), (1,), (1,))
|
||||
self.relu_strategy = ((1, 1, 1, device_num),)
|
||||
self.maxpool_strategy = ((1, 1, 1, device_num),)
|
||||
self.add_strategy = ((1, 1, 1, device_num), (1, 1, 1, device_num))
|
||||
|
||||
self.conv1.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv1.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv1.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.conv1.cell_list[3].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv1.cell_list[4].bn_train.shard(self.bn_strategy)
|
||||
self.conv1.cell_list[5].relu.shard(self.relu_strategy)
|
||||
self.pool1.max_pool.shard(self.maxpool_strategy)
|
||||
self.conv2.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv2.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv2.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.conv2.cell_list[3].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv2.cell_list[4].bn_train.shard(self.bn_strategy)
|
||||
self.conv2.cell_list[5].relu.shard(self.relu_strategy)
|
||||
self.pool2.max_pool.shard(self.maxpool_strategy)
|
||||
self.conv3.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv3.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv3.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.conv3.cell_list[3].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv3.cell_list[4].bn_train.shard(self.bn_strategy)
|
||||
self.conv3.cell_list[5].relu.shard(self.relu_strategy)
|
||||
self.conv3.cell_list[6].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv3.cell_list[7].bn_train.shard(self.bn_strategy)
|
||||
self.conv3.cell_list[8].relu.shard(self.relu_strategy)
|
||||
self.pool3.max_pool.shard(self.maxpool_strategy)
|
||||
self.conv4.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv4.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv4.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.conv4.cell_list[3].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv4.cell_list[4].bn_train.shard(self.bn_strategy)
|
||||
self.conv4.cell_list[5].relu.shard(self.relu_strategy)
|
||||
self.conv4.cell_list[6].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv4.cell_list[7].bn_train.shard(self.bn_strategy)
|
||||
self.conv4.cell_list[8].relu.shard(self.relu_strategy)
|
||||
self.pool4.max_pool.shard(self.maxpool_strategy)
|
||||
self.conv5.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv5.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv5.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.conv5.cell_list[3].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv5.cell_list[4].bn_train.shard(self.bn_strategy)
|
||||
self.conv5.cell_list[5].relu.shard(self.relu_strategy)
|
||||
self.conv5.cell_list[6].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv5.cell_list[7].bn_train.shard(self.bn_strategy)
|
||||
self.conv5.cell_list[8].relu.shard(self.relu_strategy)
|
||||
self.pool5.max_pool.shard(((1, 1, 1, device_num),))
|
||||
self.conv6.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv6.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv6.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.conv7.cell_list[0].conv2d.shard(self.conv2d_strategy)
|
||||
self.conv7.cell_list[1].bn_train.shard(self.bn_strategy)
|
||||
self.conv7.cell_list[2].relu.shard(self.relu_strategy)
|
||||
self.score_fr.conv2d.shard(self.conv2d_strategy)
|
||||
self.upscore2.conv2d_transpose.shard(self.conv2d_strategy)
|
||||
self.score_pool4.conv2d.shard(self.conv2d_strategy)
|
||||
self.upscore_pool4.conv2d_transpose.shard(self.conv2d_strategy)
|
||||
self.score_pool3.conv2d.shard(self.conv2d_strategy)
|
||||
self.upscore8.conv2d_transpose.shard(self.conv2d_strategy)
|
||||
self.add1.shard(self.add_strategy)
|
||||
self.add2.shard(self.add_strategy)
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.conv1(x)
|
||||
p1 = self.pool1(x1)
|
||||
x2 = self.conv2(p1)
|
||||
p2 = self.pool2(x2)
|
||||
x3 = self.conv3(p2)
|
||||
p3 = self.pool3(x3)
|
||||
x4 = self.conv4(p3)
|
||||
p4 = self.pool4(x4)
|
||||
x5 = self.conv5(p4)
|
||||
p5 = self.pool5(x5)
|
||||
|
||||
x6 = self.conv6(p5)
|
||||
x7 = self.conv7(x6)
|
||||
|
||||
sf = self.score_fr(x7)
|
||||
u2 = self.upscore2(sf)
|
||||
|
||||
s4 = self.score_pool4(p4)
|
||||
f4 = self.add1(s4, u2)
|
||||
u4 = self.upscore_pool4(f4)
|
||||
|
||||
s3 = self.score_pool3(p3)
|
||||
f3 = self.add2(s3, u4)
|
||||
out = self.upscore8(f3)
|
||||
|
||||
return out
|
|
@ -1,654 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
learning rate scheduler
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["LambdaLR", "MultiplicativeLR", "StepLR", "MultiStepLR", "ExponentialLR",
|
||||
"CosineAnnealingLR", "CyclicLR", "CosineAnnealingWarmRestarts", "OneCycleLR"]
|
||||
|
||||
class _WarmUp():
|
||||
def __init__(self, warmup_init_lr):
|
||||
self.warmup_init_lr = warmup_init_lr
|
||||
|
||||
def get_lr(self):
|
||||
# Get learning rate during warmup
|
||||
raise NotImplementedError
|
||||
|
||||
class _LinearWarmUp(_WarmUp):
|
||||
"""
|
||||
linear warmup function
|
||||
"""
|
||||
def __init__(self, lr, warmup_epochs, steps_per_epoch, warmup_init_lr=0):
|
||||
self.base_lr = lr
|
||||
self.warmup_init_lr = warmup_init_lr
|
||||
self.warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
super(_LinearWarmUp, self).__init__(warmup_init_lr)
|
||||
|
||||
def get_warmup_steps(self):
|
||||
return self.warmup_steps
|
||||
|
||||
def get_lr(self, current_step):
|
||||
lr_inc = (float(self.base_lr) - float(self.warmup_init_lr)) / float(self.warmup_steps)
|
||||
lr = float(self.warmup_init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
class _ConstWarmUp(_WarmUp):
|
||||
|
||||
def get_lr(self):
|
||||
return self.warmup_init_lr
|
||||
|
||||
class _LRScheduler():
|
||||
|
||||
def __init__(self, lr, max_epoch, steps_per_epoch):
|
||||
self.base_lr = lr
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.total_steps = int(max_epoch * steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
# Compute learning rate using chainable form of the scheduler
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LambdaLR(_LRScheduler):
|
||||
"""Sets the learning rate to the initial lr times a given function.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
lr_lambda (function or list): A function which computes a multiplicative
|
||||
factor given an integer parameter epoch.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
Example:
|
||||
>>> # Assuming optimizer has two groups.
|
||||
>>> lambda1 = lambda epoch: epoch // 30
|
||||
>>> scheduler = LambdaLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
|
||||
def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.lr_lambda = lr_lambda
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(LambdaLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
lr = self.base_lr * self.lr_lambda(cur_ep)
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class MultiplicativeLR(_LRScheduler):
|
||||
"""Multiply the learning rate by the factor given
|
||||
in the specified function.
|
||||
|
||||
Args:
|
||||
lr_lambda (function or list): A function which computes a multiplicative
|
||||
factor given an integer parameter epoch,.
|
||||
|
||||
Example:
|
||||
>>> lmbda = lambda epoch: 0.95
|
||||
>>> scheduler = MultiplicativeLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.lr_lambda = lr_lambda
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(MultiplicativeLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
if i % self.steps_per_epoch == 0 and cur_ep > 0:
|
||||
current_lr = current_lr * self.lr_lambda(cur_ep)
|
||||
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class StepLR(_LRScheduler):
|
||||
"""Decays the learning rate by gamma every epoch_size epochs.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
epoch_size (int): Period of learning rate decay.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
Default: 0.1.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
Example:
|
||||
>>> # Assuming optimizer uses lr = 0.05 for all groups
|
||||
>>> # lr = 0.05 if epoch < 30
|
||||
>>> # lr = 0.005 if 30 <= epoch < 60
|
||||
>>> # lr = 0.0005 if 60 <= epoch < 90
|
||||
>>> # ...
|
||||
>>> scheduler = StepLR(lr=0.1, epoch_size=30, gamma=0.1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
|
||||
def __init__(self, lr, epoch_size, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.epoch_size = epoch_size
|
||||
self.gamma = gamma
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(StepLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
lr = self.base_lr * self.gamma**(cur_ep // self.epoch_size)
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
"""Decays the learning rate by gamma once the number of epoch reaches one
|
||||
of the milestones.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
milestones (list): List of epoch indices. Must be increasing.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
Default: 0.1.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
Example:
|
||||
>>> # Assuming optimizer uses lr = 0.05 for all groups
|
||||
>>> # lr = 0.05 if epoch < 30
|
||||
>>> # lr = 0.005 if 30 <= epoch < 80
|
||||
>>> # lr = 0.0005 if epoch >= 80
|
||||
>>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
|
||||
def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.milestones = Counter(milestones)
|
||||
self.gamma = gamma
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
|
||||
current_lr = current_lr * self.gamma
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
"""Decays the learning rate of each parameter group by gamma every epoch.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
"""
|
||||
|
||||
def __init__(self, lr, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.gamma = gamma
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(ExponentialLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
if i % self.steps_per_epoch == 0 and i > 0:
|
||||
current_lr = current_lr * self.gamma
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
r"""Set the learning rate using a cosine annealing schedule, where
|
||||
:math:`\eta_{max}` is set to the initial lr and :math:`T_{cur}` is the
|
||||
number of epochs since the last restart in SGDR:
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
|
||||
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
|
||||
& T_{cur} \neq (2k+1)T_{max}; \\
|
||||
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
|
||||
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
|
||||
& T_{cur} = (2k+1)T_{max}.
|
||||
\end{aligned}
|
||||
|
||||
It has been proposed in
|
||||
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
|
||||
implements the cosine annealing part of SGDR, and not the restarts.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
T_max (int): Maximum number of iterations.
|
||||
eta_min (float): Minimum learning rate. Default: 0.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
|
||||
def __init__(self, lr, T_max, steps_per_epoch, max_epoch, warmup_epochs=0, eta_min=0):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(CosineAnnealingLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
if i % self.steps_per_epoch == 0 and i > 0:
|
||||
current_lr = self.eta_min + \
|
||||
(self.base_lr - self.eta_min) * (1. + math.cos(math.pi*cur_ep / self.T_max)) / 2
|
||||
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CyclicLR(_LRScheduler):
|
||||
r"""Sets the learning rate according to cyclical learning rate policy (CLR).
|
||||
The policy cycles the learning rate between two boundaries with a constant
|
||||
frequency, as detailed in the paper `Cyclical Learning Rates for Training
|
||||
Neural Networks`_. The distance between the two boundaries can be scaled on
|
||||
a per-iteration or per-cycle basis.
|
||||
|
||||
Cyclical learning rate policy changes the learning rate after every batch.
|
||||
|
||||
This class has three built-in policies, as put forth in the paper:
|
||||
|
||||
* "triangular": A basic triangular cycle without amplitude scaling.
|
||||
* "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
|
||||
* "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
|
||||
at each cycle iteration.
|
||||
|
||||
This implementation was adapted from the github repo: `bckenstler/CLR`_
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
max_lr (float): Upper learning rate boundaries in the cycle.
|
||||
Functionally, it defines the cycle amplitude (max_lr - base_lr).
|
||||
The lr at any cycle is the sum of base_lr and some scaling
|
||||
of the amplitude; therefore max_lr may not actually be reached
|
||||
depending on scaling function.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
step_size_up (int): Number of training iterations in the
|
||||
increasing half of a cycle. Default: 2000
|
||||
step_size_down (int): Number of training iterations in the
|
||||
decreasing half of a cycle. If step_size_down is None,
|
||||
it is set to step_size_up. Default: None
|
||||
mode (str): One of {triangular, triangular2, exp_range}.
|
||||
Values correspond to policies detailed above.
|
||||
If scale_fn is not None, this argument is ignored.
|
||||
Default: 'triangular'
|
||||
gamma (float): Constant in 'exp_range' scaling function:
|
||||
gamma**(cycle iterations)
|
||||
Default: 1.0
|
||||
scale_fn (function): Custom scaling policy defined by a single
|
||||
argument lambda function, where
|
||||
0 <= scale_fn(x) <= 1 for all x >= 0.
|
||||
If specified, then 'mode' is ignored.
|
||||
Default: None
|
||||
scale_mode (str): {'cycle', 'iterations'}.
|
||||
Defines whether scale_fn is evaluated on
|
||||
cycle number or cycle iterations (training
|
||||
iterations since start of cycle).
|
||||
Default: 'cycle'
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
|
||||
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lr,
|
||||
max_lr,
|
||||
steps_per_epoch,
|
||||
max_epoch,
|
||||
step_size_up=2000,
|
||||
step_size_down=None,
|
||||
mode='triangular',
|
||||
gamma=1.,
|
||||
scale_fn=None,
|
||||
scale_mode='cycle',
|
||||
warmup_epochs=0):
|
||||
|
||||
self.max_lr = max_lr
|
||||
|
||||
step_size_up = float(step_size_up)
|
||||
step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
|
||||
self.total_size = step_size_up + step_size_down
|
||||
self.step_ratio = step_size_up / self.total_size
|
||||
|
||||
if mode not in ['triangular', 'triangular2', 'exp_range'] \
|
||||
and scale_fn is None:
|
||||
raise ValueError('mode is invalid and scale_fn is None')
|
||||
|
||||
self.mode = mode
|
||||
self.gamma = gamma
|
||||
|
||||
if scale_fn is None:
|
||||
mode_map = {
|
||||
'triangular': ['cycle', self._triangular_scale_fn],
|
||||
'triangular2': ['cycle', self._triangular2_scale_fn],
|
||||
'exp_range': ['iterations', self._exp_range_scale_fn]
|
||||
}
|
||||
self.scale_mode = mode_map.get(self.mode)[0]
|
||||
self.scale_fn = mode_map.get(self.mode)[1]
|
||||
else:
|
||||
self.scale_fn = scale_fn
|
||||
self.scale_mode = scale_mode
|
||||
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(CyclicLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def _triangular_scale_fn(self, x):
|
||||
return 1.
|
||||
|
||||
def _triangular2_scale_fn(self, x):
|
||||
return 1 / (2. ** (x - 1))
|
||||
|
||||
def _exp_range_scale_fn(self, x):
|
||||
return self.gamma**(x)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
# Calculates the learning rate at batch index.
|
||||
cycle = math.floor(1 + i / self.total_size)
|
||||
x = 1. + i / self.total_size - cycle
|
||||
if x <= self.step_ratio:
|
||||
scale_factor = x / self.step_ratio
|
||||
else:
|
||||
scale_factor = (x - 1) / (self.step_ratio - 1)
|
||||
|
||||
base_height = (self.max_lr - self.base_lr) * scale_factor
|
||||
if self.scale_mode == 'cycle':
|
||||
lr = self.base_lr + base_height * self.scale_fn(cycle)
|
||||
else:
|
||||
lr = self.base_lr + base_height * self.scale_fn(i)
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CosineAnnealingWarmRestarts(_LRScheduler):
|
||||
r"""Set the learning rate using a cosine annealing schedule, where
|
||||
:math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` is the
|
||||
number of epochs since the last restart and :math:`T_{i}` is the number
|
||||
of epochs between two warm restarts in SGDR:
|
||||
|
||||
.. math::
|
||||
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
|
||||
\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
|
||||
|
||||
When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
|
||||
When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
|
||||
|
||||
It has been proposed in
|
||||
`SGDR: Stochastic Gradient Descent with Warm Restarts`_.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
T_0 (int): Number of iterations for the first restart.
|
||||
T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
|
||||
eta_min (float, optional): Minimum learning rate. Default: 0.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
|
||||
def __init__(self, lr, steps_per_epoch, max_epoch, T_0, T_mult=1, eta_min=0, warmup_epochs=0):
|
||||
if T_0 <= 0 or not isinstance(T_0, int):
|
||||
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
|
||||
if T_mult < 1 or not isinstance(T_mult, int):
|
||||
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
|
||||
self.T_0 = T_0
|
||||
self.T_i = T_0
|
||||
self.T_mult = T_mult
|
||||
self.eta_min = eta_min
|
||||
self.T_cur = 0
|
||||
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(CosineAnnealingWarmRestarts, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
if i % self.steps_per_epoch == 0 and i > 0:
|
||||
self.T_cur += 1
|
||||
if self.T_cur >= self.T_i:
|
||||
self.T_cur = self.T_cur - self.T_i
|
||||
self.T_i = self.T_i * self.T_mult
|
||||
|
||||
lr = self.eta_min + (self.base_lr - self.eta_min) * \
|
||||
(1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class OneCycleLR(_LRScheduler):
|
||||
r"""Sets the learning rate of each parameter group according to the
|
||||
1cycle learning rate policy. The 1cycle policy anneals the learning
|
||||
rate from an initial learning rate to some maximum learning rate and then
|
||||
from that maximum learning rate to some minimum learning rate much lower
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every batch.
|
||||
This scheduler is not chainable.
|
||||
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
pct_start (float): The percentage of the cycle (in number of steps) spent
|
||||
increasing the learning rate.
|
||||
Default: 0.3
|
||||
anneal_strategy (str): {'cos', 'linear'}
|
||||
Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
|
||||
linear annealing.
|
||||
Default: 'cos'
|
||||
div_factor (float): Determines the max learning rate via
|
||||
max_lr = lr * div_factor
|
||||
Default: 25
|
||||
final_div_factor (float): Determines the minimum learning rate via
|
||||
min_lr = lr / final_div_factor
|
||||
Default: 1e4
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
|
||||
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
|
||||
https://arxiv.org/abs/1708.07120
|
||||
"""
|
||||
def __init__(self,
|
||||
lr,
|
||||
steps_per_epoch,
|
||||
max_epoch,
|
||||
pct_start=0.3,
|
||||
anneal_strategy='cos',
|
||||
div_factor=25.,
|
||||
final_div_factor=1e4,
|
||||
warmup_epochs=0):
|
||||
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(OneCycleLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
self.step_size_up = float(pct_start * self.total_steps) - 1
|
||||
self.step_size_down = float(self.total_steps - self.step_size_up) - 1
|
||||
|
||||
# Validate pct_start
|
||||
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
|
||||
raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
|
||||
|
||||
# Validate anneal_strategy
|
||||
if anneal_strategy not in ['cos', 'linear']:
|
||||
raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
|
||||
if anneal_strategy == 'cos':
|
||||
self.anneal_func = self._annealing_cos
|
||||
elif anneal_strategy == 'linear':
|
||||
self.anneal_func = self._annealing_linear
|
||||
|
||||
# Initialize learning rate variables
|
||||
self.max_lr = lr * div_factor
|
||||
self.min_lr = lr / final_div_factor
|
||||
|
||||
def _annealing_cos(self, start, end, pct):
|
||||
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
||||
cos_out = math.cos(math.pi * pct) + 1
|
||||
return end + (start - end) / 2.0 * cos_out
|
||||
|
||||
def _annealing_linear(self, start, end, pct):
|
||||
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
||||
return (end - start) * pct + start
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
if i <= self.step_size_up:
|
||||
lr = self.anneal_func(self.base_lr, self.max_lr, i / self.step_size_up)
|
||||
|
||||
else:
|
||||
down_step_num = i - self.step_size_up
|
||||
lr = self.anneal_func(self.max_lr, self.min_lr, down_step_num / self.step_size_down)
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
|
@ -1,159 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train FCN8s."""
|
||||
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
from src.data import dataset as data_generator
|
||||
from src.loss import loss
|
||||
from src.utils.lr_scheduler import CosineAnnealingLR
|
||||
from src.nets.FCN8s import FCN8s
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
device_num = get_device_num()
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
|
||||
device_target=config.device_target, device_id=get_device_id())
|
||||
# init multicards training
|
||||
config.rank = 0
|
||||
config.group_size = 1
|
||||
if device_num > 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
if config.parallel_mode in ParallelMode.MODE_LIST:
|
||||
parallel_mode = config.parallel_mode
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||
init()
|
||||
config.rank = get_rank()
|
||||
config.group_size = get_group_size()
|
||||
|
||||
# dataset
|
||||
dataset = data_generator.SegDataset(image_mean=config.image_mean,
|
||||
image_std=config.image_std,
|
||||
data_file=config.data_file,
|
||||
batch_size=config.train_batch_size,
|
||||
crop_size=config.crop_size,
|
||||
max_scale=config.max_scale,
|
||||
min_scale=config.min_scale,
|
||||
ignore_label=config.ignore_label,
|
||||
num_classes=config.num_classes,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=config.rank,
|
||||
shard_num=config.group_size)
|
||||
dataset = dataset.get_dataset(repeat=1)
|
||||
|
||||
net = FCN8s(n_class=config.num_classes)
|
||||
if context.get_auto_parallel_context("parallel_mode") in [ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
ParallelMode.AUTO_PARALLEL]:
|
||||
net.set_model_parallel_shard_strategy(device_num)
|
||||
loss_ = loss.SoftmaxCrossEntropyLoss(config.num_classes, config.ignore_label, device_num=device_num)
|
||||
|
||||
# load pretrained vgg16 parameters to init FCN8s
|
||||
if config.ckpt_vgg16:
|
||||
param_vgg = load_checkpoint(config.ckpt_vgg16)
|
||||
is_model_zoo_vgg16 = "layers.0.weight" in param_vgg
|
||||
if is_model_zoo_vgg16:
|
||||
_is_bn_used = "layers.1.gamma" in param_vgg
|
||||
_is_imagenet_used = param_vgg['classifier.6.bias'].shape == (1000,)
|
||||
if not _is_bn_used or not _is_imagenet_used:
|
||||
raise Exception("Please use the vgg16 checkpoint which use BN and trained by ImageNet dataset.")
|
||||
idx = 0
|
||||
param_dict = {}
|
||||
for layer_id in range(1, 6):
|
||||
sub_layer_num = 2 if layer_id < 3 else 3
|
||||
for sub_layer_id in range(sub_layer_num):
|
||||
# conv param
|
||||
y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
|
||||
x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
|
||||
if is_model_zoo_vgg16:
|
||||
x_weight = 'layers.{}.weight'.format(idx)
|
||||
param_dict[y_weight] = param_vgg[x_weight]
|
||||
# BatchNorm param
|
||||
y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1)
|
||||
y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1)
|
||||
x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1)
|
||||
x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1)
|
||||
if is_model_zoo_vgg16:
|
||||
x_gamma = 'layers.{}.gamma'.format(idx + 1)
|
||||
x_beta = 'layers.{}.beta'.format(idx + 1)
|
||||
param_dict[y_gamma] = param_vgg[x_gamma]
|
||||
param_dict[y_beta] = param_vgg[x_beta]
|
||||
idx += 3
|
||||
idx += 1
|
||||
load_param_into_net(net, param_dict)
|
||||
# load pretrained FCN8s
|
||||
elif config.ckpt_pre_trained:
|
||||
param_dict = load_checkpoint(config.ckpt_pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
lr_scheduler = CosineAnnealingLR(config.base_lr,
|
||||
config.train_epochs,
|
||||
iters_per_epoch,
|
||||
config.train_epochs,
|
||||
warmup_epochs=0,
|
||||
eta_min=0)
|
||||
lr = Tensor(lr_scheduler.get_lr())
|
||||
|
||||
# loss scale
|
||||
manager_loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
if config.device_target == "Ascend":
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=config.loss_scale)
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
elif config.device_target == "GPU":
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
|
||||
model = Model(net, loss_fn=loss_, optimizer=optimizer)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
# callback for saving ckpts
|
||||
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
||||
loss_cb = LossMonitor()
|
||||
cbs = [time_cb, loss_cb]
|
||||
|
||||
if config.rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=config.model, directory=config.ckpt_dir, config=config_ck)
|
||||
cbs.append(ckpoint_cb)
|
||||
|
||||
model.train(config.train_epochs, dataset, callbacks=cbs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
|
@ -1,195 +0,0 @@
|
|||
# Contents
|
||||
|
||||
- [MCNN Description](#mcnn-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [MCNN Description](#contents)
|
||||
|
||||
MCNN was a Multi-column Convolution Neural Network which can estimate crowd number accurately in a single image from almost any perspective.
|
||||
|
||||
[Paper](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Zhang_Single-Image_Crowd_Counting_CVPR_2016_paper.pdf): Yingying Zhang, Desen Zhou, Siqin Chen, Shenghua Gao, Yi Ma. Single-Image Crowd Counting via Multi-Column Convolutional Neural Network.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
MCNN contains three parallel CNNs whose filters are with local receptive fields of different sizes. For simplification, we use the same network structures for all columns (i.e.,conv–pooling–conv–pooling) except for the sizes and numbers of filters. Max pooling is applied for each 2×2 region, and Rectified linear unit (ReLU) is adopted as the activation function because of its good performance for CNNs.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [ShanghaitechA](<https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0>)
|
||||
|
||||
```text
|
||||
├─data
|
||||
├─formatted_trainval
|
||||
├─shanghaitech_part_A_patches_9
|
||||
├─train
|
||||
├─train-den
|
||||
├─val
|
||||
├─val-den
|
||||
├─original
|
||||
├─shanghaitech
|
||||
├─part_A_final
|
||||
├─train_data
|
||||
├─images
|
||||
├─ground_truth
|
||||
├─test_data
|
||||
├─images
|
||||
├─ground_truth
|
||||
├─ground_truth_csv
|
||||
```
|
||||
|
||||
- note: formatted_trainval dir is generated by file [create_training_set_shtech](https://github.com/svishwa/crowdcount-mcnn/blob/master/data_preparation/create_training_set_shtech.m)
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware (Ascend)
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
```bash
|
||||
# enter script dir, train MCNN example
|
||||
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
|
||||
# enter script dir, evaluate MCNN example
|
||||
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```text
|
||||
├── cv
|
||||
├── MCNN
|
||||
├── README.md // descriptions about MCNN
|
||||
├── scripts
|
||||
│ ├──run_distribute_train.sh // train in distribute
|
||||
│ ├──run_eval.sh // eval in ascend
|
||||
│ ├──run_standalone_train.sh // train in standalone
|
||||
├── src
|
||||
│ ├──dataset.py // creating dataset
|
||||
│ ├──mcnn.py // mcnn architecture
|
||||
│ ├──config.py // parameter configuration
|
||||
│ ├──data_loader.py // prepare dataset loader(GREY)
|
||||
│ ├──data_loader_3channel.py // prepare dataset loader(RGB)
|
||||
│ ├──evaluate_model.py // evaluate model
|
||||
│ ├──generator_lr.py // generator learning rate
|
||||
│ ├──Mcnn_Callback.py // Mcnn Callback
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
├── export.py // export script
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
```python # parameters
|
||||
Major parameters in train.py and config.py as follows:
|
||||
|
||||
--data_path: The absolute full path to the train and evaluation datasets.
|
||||
--epoch_size: Total training epochs.
|
||||
--batch_size: Training batch size.
|
||||
--device_target: Device where the code will be implemented. Optional values are "Ascend", "GPU".
|
||||
--ckpt_path: The absolute full path to the checkpoint file saved after training.
|
||||
--train_path: Training dataset's data
|
||||
--train_gt_path: Training dataset's label
|
||||
--val_path: Testing dataset's data
|
||||
--val_gt_path: Testing dataset's label
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Training
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```bash
|
||||
# enter script dir, and run the distribute script
|
||||
sh run_distribute_train.sh ./hccl_table.json ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
|
||||
# enter script dir, and run the standalone script
|
||||
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
|
||||
```
|
||||
|
||||
After training, the loss value will be achieved as follows:
|
||||
|
||||
```text
|
||||
# grep "loss is " log
|
||||
epoch: 1 step: 305, loss is 0.00041025918
|
||||
epoch: 2 step: 305, loss is 3.7117527e-05
|
||||
...
|
||||
epoch: 798 step: 305, loss is 0.000332611
|
||||
epoch: 799 step: 305, loss is 2.6959011e-05
|
||||
epoch: 800 step: 305, loss is 5.6599742e-06
|
||||
...
|
||||
```
|
||||
|
||||
The model checkpoint will be saved in the current directory.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Evaluation
|
||||
|
||||
Before running the command below, please check the checkpoint path used for evaluation.
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```bash
|
||||
# enter script dir, and run the script
|
||||
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
|
||||
```
|
||||
|
||||
You can view the results through the file "eval_log". The accuracy of the test dataset will be as follows:
|
||||
|
||||
```text
|
||||
# grep "MAE: " eval_log
|
||||
MAE: 105.87984801910736 MSE: 161.6687899899305
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ------------------------------------------------------------|
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
|
||||
| uploaded Date | 06/29/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | ShanghaitechA |
|
||||
| Training Parameters | steps=2439, batch_size = 1 |
|
||||
| Optimizer | Momentum |
|
||||
| outputs | probability |
|
||||
| Speed | 5.79 ms/step |
|
||||
| Total time | 23 mins |
|
||||
| Checkpoint for Fine tuning | 500.94KB (.ckpt file) |
|
||||
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/MCNN | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
In dataset.py, we set the seed inside ```create_dataset``` function.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -1,100 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
######################## train mcnn example ########################
|
||||
train mcnn and get network model files(.ckpt) :
|
||||
python eval.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
from src.dataset import create_dataset
|
||||
from src.mcnn import MCNN
|
||||
from src.data_loader_3channel import ImageDataLoader_3channel
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
import numpy as np
|
||||
|
||||
local_path = '/cache/val_path'
|
||||
local_gt_path = '/cache/val_gt_path'
|
||||
local_ckpt_url = '/cache/ckpt'
|
||||
ckptpath = "obs://lhb1234/MCNN/ckpt"
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
|
||||
parser.add_argument('--run_offline', type=ast.literal_eval,
|
||||
default=True, help='run in offline is False or True')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
parser.add_argument('--val_path', required=True,
|
||||
default='/data/mcnn/original/shanghaitech/part_A_final/test_data/images',
|
||||
help='Location of data.')
|
||||
parser.add_argument('--val_gt_path', required=True,
|
||||
default='/data/mcnn/original/shanghaitech/part_A_final/test_data/ground_truth_csv',
|
||||
help='Location of data.')
|
||||
args = parser.parse_args()
|
||||
set_seed(64678)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
|
||||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(save_graphs=False)
|
||||
|
||||
if device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
if args.run_offline:
|
||||
local_path = args.val_path
|
||||
local_gt_path = args.val_gt_path
|
||||
local_ckpt_url = args.ckpt_path
|
||||
else:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=args.val_path, dst_url=local_path)
|
||||
mox.file.copy_parallel(src_url=args.val_gt_path, dst_url=local_gt_path)
|
||||
mox.file.copy_parallel(src_url=ckptpath, dst_url=local_ckpt_url)
|
||||
|
||||
data_loader_val = ImageDataLoader_3channel(local_path, local_gt_path, shuffle=False, gt_downsample=True,
|
||||
pre_load=True)
|
||||
ds_val = create_dataset(data_loader_val, target=args.device_target, train=False)
|
||||
ds_val = ds_val.batch(1)
|
||||
network = MCNN()
|
||||
|
||||
model_name = local_ckpt_url
|
||||
print(model_name)
|
||||
mae = 0.0
|
||||
mse = 0.0
|
||||
load_checkpoint(model_name, net=network)
|
||||
network.set_train(False)
|
||||
for sample in ds_val.create_dict_iterator():
|
||||
im_data = sample['data']
|
||||
gt_data = sample['gt_density']
|
||||
density_map = network(im_data)
|
||||
gt_count = np.sum(gt_data.asnumpy())
|
||||
et_count = np.sum(density_map.asnumpy())
|
||||
mae += abs(gt_count-et_count)
|
||||
mse += ((gt_count-et_count) * (gt_count-et_count))
|
||||
mae = mae / ds_val.get_dataset_size()
|
||||
mse = np.sqrt(mse / ds_val.get_dataset_size())
|
||||
print('MAE:', mae, ' MSE:', mse)
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
from src.config import crowd_cfg as cfg
|
||||
from src.mcnn import MCNN
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument("--device_id", type=int, default=4, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="mcnn", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# define fusion network
|
||||
network = MCNN()
|
||||
# load network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([args.batch_size, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
export(network, inputs, file_name=args.file_name, file_format=args.file_format)
|
|
@ -1,45 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
ulimit -u unlimited
|
||||
export RANK_SIZE=8
|
||||
export DEVICE_NUM=8
|
||||
export RANK_TABLE_FILE=$1
|
||||
export TRAIN_PATH=$2
|
||||
export TRAIN_GT_PATH=$3
|
||||
export VAL_PATH=$4
|
||||
export VAL_GT_PATH=$5
|
||||
export CKPT_PATH=$6
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=${i}
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
|
||||
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
cd ..
|
||||
done
|
|
@ -1,44 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [DEVICE_ID] [VAL_PATH] [VAL_GT_PATH] [CKPT_PATH] "
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export RANK_SIZE=1
|
||||
export DEVICE_ID=$1
|
||||
export VAL_PATH=$2
|
||||
export VAL_GT_PATH=$3
|
||||
export CKPT_PATH=$4
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python -u eval.py --device_id=$DEVICE_ID --val_path=$VAL_PATH \
|
||||
--val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
cd ..
|
|
@ -1,43 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
ulimit -u unlimited
|
||||
export RANK_SIZE=1
|
||||
export DEVICE_ID=$1
|
||||
export TRAIN_PATH=$2
|
||||
export TRAIN_GT_PATH=$3
|
||||
export VAL_PATH=$4
|
||||
export VAL_GT_PATH=$5
|
||||
export CKPT_PATH=$6
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.
|
||||
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
|
||||
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
fi
|
||||
cd ..
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""This is callback program"""
|
||||
import os
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from src.evaluate_model import evaluate_model
|
||||
|
||||
|
||||
class mcnn_callback(Callback):
|
||||
def __init__(self, net, eval_data, run_offline, ckpt_path):
|
||||
self.net = net
|
||||
self.eval_data = eval_data
|
||||
self.best_mae = 999999
|
||||
self.best_mse = 999999
|
||||
self.best_epoch = 0
|
||||
self.path_url = "/cache/train_output"
|
||||
self.run_offline = run_offline
|
||||
self.ckpt_path = ckpt_path
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
# print(self.net.trainable_params()[0].data.asnumpy()[0][0])
|
||||
mae, mse = evaluate_model(self.net, self.eval_data)
|
||||
cb_param = run_context.original_args()
|
||||
cur_epoch = cb_param.cur_epoch_num
|
||||
if cur_epoch % 2 == 0:
|
||||
if mae < self.best_mae:
|
||||
self.best_mae = mae
|
||||
self.best_mse = mse
|
||||
self.best_epoch = cur_epoch
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
if (device_num == 1) or (device_num == 8 and device_id == 0):
|
||||
# save_checkpoint(self.net, path_url+'/best.ckpt')
|
||||
if self.run_offline:
|
||||
self.path_url = self.ckpt_path
|
||||
if not os.path.exists(self.path_url):
|
||||
os.makedirs(self.path_url, exist_ok=True)
|
||||
save_checkpoint(self.net, os.path.join(self.path_url, 'best.ckpt'))
|
||||
|
||||
log_text = 'EPOCH: %d, MAE: %.1f, MSE: %0.1f' % (cur_epoch, mae, mse)
|
||||
print(log_text)
|
||||
log_text = 'BEST MAE: %0.1f, BEST MSE: %0.1f, BEST EPOCH: %s' \
|
||||
% (self.best_mae, self.best_mse, self.best_epoch)
|
||||
print(log_text)
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
crowd_cfg = edict({
|
||||
'lr': 0.000028,# 0.00001 if device_num == 1; 0.00003 device_num=8
|
||||
'momentum': 0.0,
|
||||
'epoch_size': 800,
|
||||
'batch_size': 1,
|
||||
'buffer_size': 1000,
|
||||
'save_checkpoint_steps': 1,
|
||||
'keep_checkpoint_max': 10,
|
||||
'air_name': "mcnn",
|
||||
})
|
|
@ -1,136 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""image dataloader"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class ImageDataLoader():
|
||||
def __init__(self, data_path, gt_path, shuffle=False, gt_downsample=False, pre_load=False):
|
||||
# pre_load: if true, all training and validation images are loaded into CPU RAM for faster processing.
|
||||
# This avoids frequent file reads. Use this only for small datasets.
|
||||
self.data_path = data_path
|
||||
self.gt_path = gt_path
|
||||
self.gt_downsample = gt_downsample
|
||||
self.pre_load = pre_load
|
||||
self.data_files = [filename for filename in os.listdir(data_path) \
|
||||
if os.path.isfile(os.path.join(data_path, filename))]
|
||||
self.data_files.sort()
|
||||
self.shuffle = shuffle
|
||||
if shuffle:
|
||||
random.seed(2468)
|
||||
self.num_samples = len(self.data_files)
|
||||
self.blob_list = {}
|
||||
self.id_list = range(0, self.num_samples)
|
||||
if self.pre_load:
|
||||
print('Pre-loading the data. This may take a while...')
|
||||
idx = 0
|
||||
for fname in self.data_files:
|
||||
|
||||
img = cv2.imread(os.path.join(self.data_path, fname), 0)
|
||||
img = img.astype(np.float32, copy=False)
|
||||
ht = img.shape[0]
|
||||
wd = img.shape[1]
|
||||
ht_1 = (ht // 4) * 4
|
||||
wd_1 = (wd // 4) * 4
|
||||
img = cv2.resize(img, (wd_1, ht_1))
|
||||
|
||||
hang = (256 - ht_1) // 2
|
||||
lie = (256 - wd_1) // 2
|
||||
img = np.pad(img, ((hang, hang), (lie, lie)), 'constant')
|
||||
|
||||
img = img.reshape((1, img.shape[0], img.shape[1]))
|
||||
den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
|
||||
header=None).values
|
||||
den = den.astype(np.float32, copy=False)
|
||||
if self.gt_downsample:
|
||||
den = np.pad(den, ((hang, hang), (lie, lie)), 'constant')
|
||||
# print(den.shape)
|
||||
wd_1 = wd_1 // 4
|
||||
ht_1 = ht_1 // 4
|
||||
den = cv2.resize(den, (64, 64))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
else:
|
||||
den = cv2.resize(den, (wd_1, ht_1))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
|
||||
den = den.reshape((1, den.shape[0], den.shape[1]))
|
||||
blob = {}
|
||||
blob['data'] = img
|
||||
blob['gt_density'] = den
|
||||
blob['fname'] = fname
|
||||
self.blob_list[idx] = blob
|
||||
idx = idx + 1
|
||||
if idx % 100 == 0:
|
||||
print('Loaded ', idx, '/', self.num_samples, 'files')
|
||||
|
||||
print('Completed Loading ', idx, 'files')
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
if self.pre_load:
|
||||
random.shuffle(list(self.id_list))
|
||||
else:
|
||||
random.shuffle(list(self.data_files))
|
||||
files = self.data_files
|
||||
id_list = self.id_list
|
||||
|
||||
for idx in id_list:
|
||||
if self.pre_load:
|
||||
blob = self.blob_list[idx]
|
||||
blob['idx'] = idx
|
||||
else:
|
||||
fname = files[idx]
|
||||
img = cv2.imread(os.path.join(self.data_path, fname), 0)
|
||||
img = img.astype(np.float32, copy=False)
|
||||
ht = img.shape[0]
|
||||
wd = img.shape[1]
|
||||
ht_1 = (ht / 4) * 4
|
||||
wd_1 = (wd / 4) * 4
|
||||
img = cv2.resize(img, (wd_1, ht_1))
|
||||
|
||||
hang = (256 - ht_1) // 2
|
||||
lie = (256 - wd_1) // 2
|
||||
img = np.pad(img, ((hang, hang), (lie, lie)), 'constant')
|
||||
|
||||
img = img.reshape((1, img.shape[0], img.shape[1]))
|
||||
den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
|
||||
header=None).as_matrix()
|
||||
den = den.astype(np.float32, copy=False)
|
||||
|
||||
if self.gt_downsample:
|
||||
den = np.pad(den, ((hang, hang), (lie, lie)), 'constant')
|
||||
wd_1 = wd_1 / 4
|
||||
ht_1 = ht_1 / 4
|
||||
den = cv2.resize(den, (64, 64))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
else:
|
||||
den = cv2.resize(den, (wd_1, ht_1))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
|
||||
den = den.reshape((1, den.shape[0], den.shape[1]))
|
||||
blob = {}
|
||||
blob['data'] = img
|
||||
blob['gt_density'] = den
|
||||
blob['fname'] = fname
|
||||
|
||||
yield blob
|
||||
|
||||
def get_num_samples(self):
|
||||
return self.num_samples
|
|
@ -1,125 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ImageDataLoader_3channel"""
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class ImageDataLoader_3channel():
|
||||
def __init__(self, data_path, gt_path, shuffle=False, gt_downsample=False, pre_load=False):
|
||||
# pre_load: if true, all training and validation images are loaded into CPU RAM for faster processing.
|
||||
# This avoids frequent file reads. Use this only for small datasets.
|
||||
self.data_path = data_path
|
||||
self.gt_path = gt_path
|
||||
self.gt_downsample = gt_downsample
|
||||
self.pre_load = pre_load
|
||||
self.data_files = [filename for filename in os.listdir(data_path) \
|
||||
if os.path.isfile(os.path.join(data_path, filename))]
|
||||
self.data_files.sort()
|
||||
self.shuffle = shuffle
|
||||
if shuffle:
|
||||
random.seed(2468)
|
||||
self.num_samples = len(self.data_files)
|
||||
self.blob_list = {}
|
||||
self.id_list = range(0, self.num_samples)
|
||||
if self.pre_load:
|
||||
print('Pre-loading the data. This may take a while...')
|
||||
idx = 0
|
||||
for fname in self.data_files:
|
||||
|
||||
img = cv2.imread(os.path.join(self.data_path, fname), 0)
|
||||
img = img.astype(np.float32, copy=False)
|
||||
ht = img.shape[0]
|
||||
wd = img.shape[1]
|
||||
ht_1 = (ht // 4) * 4
|
||||
wd_1 = (wd // 4) * 4
|
||||
img = cv2.resize(img, (wd_1, ht_1))
|
||||
|
||||
img = img.reshape((1, img.shape[0], img.shape[1]))
|
||||
den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
|
||||
header=None).values
|
||||
den = den.astype(np.float32, copy=False)
|
||||
if self.gt_downsample:
|
||||
# print(den.shape)
|
||||
wd_1 = wd_1 // 4
|
||||
ht_1 = ht_1 // 4
|
||||
den = cv2.resize(den, (wd_1, ht_1))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
else:
|
||||
den = cv2.resize(den, (wd_1, ht_1))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
|
||||
den = den.reshape((1, den.shape[0], den.shape[1]))
|
||||
blob = {}
|
||||
blob['data'] = img
|
||||
blob['gt_density'] = den
|
||||
blob['fname'] = fname
|
||||
self.blob_list[idx] = blob
|
||||
idx = idx + 1
|
||||
if idx % 100 == 0:
|
||||
print('Loaded ', idx, '/', self.num_samples, 'files')
|
||||
|
||||
print('Completed Loading ', idx, 'files')
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
if self.pre_load:
|
||||
random.shuffle(list(self.id_list))
|
||||
else:
|
||||
random.shuffle(list(self.data_files))
|
||||
files = self.data_files
|
||||
id_list = self.id_list
|
||||
|
||||
for idx in id_list:
|
||||
if self.pre_load:
|
||||
blob = self.blob_list[idx]
|
||||
blob['idx'] = idx
|
||||
else:
|
||||
fname = files[idx]
|
||||
img = cv2.imread(os.path.join(self.data_path, fname), 0)
|
||||
img = img.astype(np.float32, copy=False)
|
||||
ht = img.shape[0]
|
||||
wd = img.shape[1]
|
||||
ht_1 = (ht / 4) * 4
|
||||
wd_1 = (wd / 4) * 4
|
||||
img = cv2.resize(img, (wd_1, ht_1))
|
||||
|
||||
img = img.reshape((1, img.shape[0], img.shape[1]))
|
||||
den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
|
||||
header=None).as_matrix()
|
||||
den = den.astype(np.float32, copy=False)
|
||||
|
||||
if self.gt_downsample:
|
||||
wd_1 = wd_1 / 4
|
||||
ht_1 = ht_1 / 4
|
||||
den = cv2.resize(den, (wd_1, ht_1))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
else:
|
||||
den = cv2.resize(den, (wd_1, ht_1))
|
||||
den = den * ((wd * ht) / (wd_1 * ht_1))
|
||||
|
||||
den = den.reshape((1, den.shape[0], den.shape[1]))
|
||||
blob = {}
|
||||
blob['data'] = img
|
||||
blob['gt_density'] = den
|
||||
blob['fname'] = fname
|
||||
|
||||
yield blob
|
||||
|
||||
def get_num_samples(self):
|
||||
return self.num_samples
|
|
@ -1,73 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import os
|
||||
import math
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def create_dataset(data_loader, target="Ascend", train=True):
|
||||
datalist = []
|
||||
labellist = []
|
||||
for blob in data_loader:
|
||||
datalist.append(blob['data'])
|
||||
labellist.append(blob['gt_density'])
|
||||
|
||||
class GetDatasetGenerator:
|
||||
def __init__(self):
|
||||
|
||||
self.__data = datalist
|
||||
self.__label = labellist
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.__data[index], self.__label[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__data)
|
||||
|
||||
class MySampler():
|
||||
def __init__(self, dataset, local_rank, world_size):
|
||||
self.__num_data = len(dataset)
|
||||
self.__local_rank = local_rank
|
||||
self.__world_size = world_size
|
||||
self.samples_per_rank = int(math.ceil(self.__num_data / float(self.__world_size)))
|
||||
self.total_num_samples = self.samples_per_rank * self.__world_size
|
||||
|
||||
def __iter__(self):
|
||||
indices = list(range(self.__num_data))
|
||||
indices.extend(indices[:self.total_num_samples-len(indices)])
|
||||
indices = indices[self.__local_rank:self.total_num_samples:self.__world_size]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_rank
|
||||
|
||||
dataset_generator = GetDatasetGenerator()
|
||||
sampler = MySampler(dataset_generator, local_rank=0, world_size=8)
|
||||
|
||||
if target == "Ascend":
|
||||
# device_num, rank_id = _get_rank_info()
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("DEVICE_ID"))
|
||||
sampler = MySampler(dataset_generator, local_rank=rank_id, world_size=8)
|
||||
if target != "Ascend" or device_num == 1 or (not train):
|
||||
data_set = ds.GeneratorDataset(dataset_generator, ["data", "gt_density"])
|
||||
else:
|
||||
data_set = ds.GeneratorDataset(dataset_generator, ["data", "gt_density"], num_parallel_workers=8,
|
||||
num_shards=device_num, shard_id=rank_id, sampler=sampler)
|
||||
|
||||
return data_set
|
|
@ -1,35 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluate model"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def evaluate_model(model, dataset):
|
||||
net = model
|
||||
print("*******************************************************************************************************")
|
||||
mae = 0.0
|
||||
mse = 0.0
|
||||
net.set_train(False)
|
||||
for sample in dataset.create_dict_iterator():
|
||||
im_data = sample['data']
|
||||
gt_data = sample['gt_density']
|
||||
density_map = net(im_data)
|
||||
gt_count = np.sum(gt_data.asnumpy())
|
||||
et_count = np.sum(density_map.asnumpy())
|
||||
mae += abs(gt_count - et_count)
|
||||
mse += ((gt_count - et_count) * (gt_count - et_count))
|
||||
mae = mae / (dataset.get_dataset_size())
|
||||
mse = np.sqrt(mse / dataset.get_dataset_size())
|
||||
return mae, mse
|
|
@ -1,47 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr_sha(current_step, lr_max, total_epochs, steps_per_epoch):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
current_step(int): current steps of the training
|
||||
lr_max(float): max learning rate
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
decay_epoch_index = [0.5 * total_steps, 0.75 * total_steps]
|
||||
for i in range(total_steps):
|
||||
if i < decay_epoch_index[0]:
|
||||
lr = lr_max
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr = lr_max * 0.1
|
||||
else:
|
||||
lr = lr_max * 0.01
|
||||
lr_each_step.append(lr)
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
||||
return learning_rate
|
|
@ -1,119 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""This is mcnn model"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
|
||||
class Conv2d(nn.Cell):
|
||||
"""This is Conv2d model"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, same_padding=False, bn=False):
|
||||
super(Conv2d, self).__init__()
|
||||
padding = int((kernel_size - 1) / 2) if same_padding else 0
|
||||
# padding = 'same' if same_padding else 'valid'
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
|
||||
pad_mode='pad', padding=padding, has_bias=True)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None
|
||||
self.relu = nn.ReLU() if relu else None
|
||||
# # TODO init weights
|
||||
self._initialize_weights()
|
||||
|
||||
def construct(self, x):
|
||||
"""define Conv2d network"""
|
||||
x = self.conv(x)
|
||||
# if self.bn is not None:
|
||||
# x = self.bn(x)
|
||||
if self.relu is not None:
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""initialize weights"""
|
||||
for _, m in self.cells_and_names():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
|
||||
if m.bias is not None:
|
||||
m.bias.set_data(
|
||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||
if isinstance(m, nn.Dense):
|
||||
m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
|
||||
|
||||
|
||||
def np_to_tensor(x, is_cuda=True, is_training=False):
|
||||
if is_training:
|
||||
v = Tensor(x, mstype.float32)
|
||||
else:
|
||||
v = Tensor(x, mstype.float32) # with torch.no_grad():
|
||||
return v
|
||||
|
||||
|
||||
class MCNN(nn.Cell):
|
||||
'''
|
||||
Multi-column CNN
|
||||
-Implementation of Single Image Crowd Counting via Multi-column CNN (Zhang et al.)
|
||||
'''
|
||||
def __init__(self, bn=False):
|
||||
super(MCNN, self).__init__()
|
||||
|
||||
self.branch1 = nn.SequentialCell(Conv2d(1, 16, 9, same_padding=True, bn=bn),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Conv2d(16, 32, 7, same_padding=True, bn=bn),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Conv2d(32, 16, 7, same_padding=True, bn=bn),
|
||||
Conv2d(16, 8, 7, same_padding=True, bn=bn))
|
||||
|
||||
self.branch2 = nn.SequentialCell(Conv2d(1, 20, 7, same_padding=True, bn=bn),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Conv2d(20, 40, 5, same_padding=True, bn=bn),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Conv2d(40, 20, 5, same_padding=True, bn=bn),
|
||||
Conv2d(20, 10, 5, same_padding=True, bn=bn))
|
||||
|
||||
self.branch3 = nn.SequentialCell(Conv2d(1, 24, 5, same_padding=True, bn=bn),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Conv2d(24, 48, 3, same_padding=True, bn=bn),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Conv2d(48, 24, 3, same_padding=True, bn=bn),
|
||||
Conv2d(24, 12, 3, same_padding=True, bn=bn))
|
||||
|
||||
self.fuse = nn.SequentialCell([Conv2d(30, 1, 1, same_padding=True, bn=bn)])
|
||||
|
||||
##TODO init weights
|
||||
self._initialize_weights()
|
||||
|
||||
def construct(self, im_data):
|
||||
"""define network"""
|
||||
x1 = self.branch1(im_data)
|
||||
x2 = self.branch2(im_data)
|
||||
x3 = self.branch3(im_data)
|
||||
op = ops.Concat(1)
|
||||
x = op((x1, x2, x3))
|
||||
x = self.fuse(x)
|
||||
return x
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""initialize weights"""
|
||||
for _, m in self.cells_and_names():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
|
||||
if m.bias is not None:
|
||||
m.bias.set_data(
|
||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||
if isinstance(m, nn.Dense):
|
||||
m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
|
|
@ -1,119 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
######################## train mcnn example ########################
|
||||
train mcnn and get network model files(.ckpt) :
|
||||
python train.py
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import numpy as np
|
||||
from mindspore.communication.management import init
|
||||
import mindspore.nn as nn
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train import Model
|
||||
from src.data_loader import ImageDataLoader
|
||||
from src.config import crowd_cfg as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.mcnn import MCNN
|
||||
from src.generator_lr import get_lr_sha
|
||||
from src.Mcnn_Callback import mcnn_callback
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
|
||||
parser.add_argument('--run_offline', type=ast.literal_eval,
|
||||
default=True, help='run in offline is False or True')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
|
||||
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
|
||||
parser.add_argument('--train_path', required=True, default=None, help='Location of data.')
|
||||
parser.add_argument('--train_gt_path', required=True, default=None, help='Location of data.')
|
||||
parser.add_argument('--val_path', required=True,
|
||||
default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
|
||||
help='Location of data.')
|
||||
parser.add_argument('--val_gt_path', required=True,
|
||||
default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
|
||||
help='Location of data.')
|
||||
args = parser.parse_args()
|
||||
rand_seed = 64678
|
||||
np.random.seed(rand_seed)
|
||||
|
||||
if __name__ == "__main__":
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
|
||||
print("device_num:", device_num)
|
||||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(save_graphs=False)
|
||||
|
||||
if device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if device_num > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
if args.run_offline:
|
||||
local_data1_url = args.train_path
|
||||
local_data2_url = args.train_gt_path
|
||||
local_data3_url = args.val_path
|
||||
local_data4_url = args.val_gt_path
|
||||
else:
|
||||
import moxing as mox
|
||||
local_data1_url = '/cache/train_path'
|
||||
local_data2_url = '/cache/train_gt_path'
|
||||
local_data3_url = '/cache/val_path'
|
||||
local_data4_url = '/cache/val_gt_path'
|
||||
|
||||
mox.file.copy_parallel(src_url=args.train_path, dst_url=local_data1_url) # pcl
|
||||
mox.file.copy_parallel(src_url=args.train_gt_path, dst_url=local_data2_url) # pcl
|
||||
mox.file.copy_parallel(src_url=args.val_path, dst_url=local_data3_url) # pcl
|
||||
mox.file.copy_parallel(src_url=args.val_gt_path, dst_url=local_data4_url) # pcl
|
||||
|
||||
data_loader = ImageDataLoader(local_data1_url, local_data2_url, shuffle=True, gt_downsample=True, pre_load=True)
|
||||
data_loader_val = ImageDataLoader(local_data3_url, local_data4_url,
|
||||
shuffle=False, gt_downsample=True, pre_load=True)
|
||||
ds_train = create_dataset(data_loader, target=args.device_target)
|
||||
ds_val = create_dataset(data_loader_val, target=args.device_target, train=False)
|
||||
|
||||
ds_train = ds_train.batch(cfg['batch_size'])
|
||||
ds_val = ds_val.batch(1)
|
||||
|
||||
network = MCNN()
|
||||
net_loss = nn.MSELoss(reduction='mean')
|
||||
lr = Tensor(get_lr_sha(0, cfg['lr'], cfg['epoch_size'], ds_train.get_dataset_size()))
|
||||
net_opt = nn.Adam(list(filter(lambda p: p.requires_grad, network.get_parameters())), learning_rate=lr)
|
||||
|
||||
if args.device_target != "Ascend":
|
||||
model = Model(network, net_loss, net_opt)
|
||||
else:
|
||||
model = Model(network, net_loss, net_opt, amp_level="O2")
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
eval_callback = mcnn_callback(network, ds_val, args.run_offline, args.ckpt_path)
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, eval_callback, LossMonitor(1)])
|
||||
if not args.run_offline:
|
||||
mox.file.copy_parallel(src_url='/cache/train_output', dst_url="obs://lhb1234/MCNN/ckpt")
|
|
@ -1,433 +0,0 @@
|
|||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
- [AlexNet Description](#alexnet-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Inference Process](#inference-process)
|
||||
- [Export MindIR](#export-mindir)
|
||||
- [Infer on Ascend310](#infer-on-ascend310)
|
||||
- [Result](#result)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
## [AlexNet Description](#contents)
|
||||
|
||||
AlexNet was proposed in 2012, one of the most influential neural networks. It got big success in ImageNet Dataset recognition than other models.
|
||||
|
||||
[Paper](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf): Krizhevsky A, Sutskever I, Hinton G E. ImageNet Classification with Deep ConvolutionalNeural Networks. *Advances In Neural Information Processing Systems*. 2012.
|
||||
|
||||
## [Model Architecture](#contents)
|
||||
|
||||
AlexNet composition consists of 5 convolutional layers and 3 fully connected layers. Multiple convolutional kernels can extract interesting features in images and get more accurate classification.
|
||||
|
||||
## [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
|
||||
|
||||
- Dataset size:175M,60,000 32*32 colorful images in 10 classes
|
||||
- Train:146M,50,000 images
|
||||
- Test:29.3M,10,000 images
|
||||
- Data format:binary files
|
||||
- Note:Data will be processed in dataset.py
|
||||
- Download the dataset, the directory structure is as follows:
|
||||
|
||||
```bash
|
||||
├─cifar-10-batches-bin
|
||||
│
|
||||
└─cifar-10-verify-bin
|
||||
```
|
||||
|
||||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend/GPU)
|
||||
- Prepare hardware environment with Ascend or GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
|
||||
|
||||
## [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
```python
|
||||
# enter script dir, train AlexNet
|
||||
bash run_standalone_train_ascend.sh [DATA_PATH] [CKPT_SAVE_PATH]
|
||||
# example: bash run_standalone_train_ascend.sh /home/DataSet/Cifar10/cifar-10-batches-bin/ /home/model/alexnet/ckpt/
|
||||
|
||||
# enter script dir, evaluate AlexNet
|
||||
bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_NAME]
|
||||
# example: bash run_standalone_eval_ascend.sh /home/DataSet/cifar10/cifar-10-verify-bin /home/model/cv/alxnet/ckpt/checkpoint_alexnet-1_1562.ckpt
|
||||
```
|
||||
|
||||
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```bash
|
||||
# Train 8p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "distribute=True" on default_config.yaml file.
|
||||
# Set "dataset_path='/cache/data'" on default_config.yaml file.
|
||||
# Set "epoch_size: 30" on default_config.yaml file.
|
||||
# (optional)Set "checkpoint_url='s3://dir_to_your_pretrained/'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "distribute=True" on the website UI interface.
|
||||
# Add "dataset_path=/cache/data" on the website UI interface.
|
||||
# Add "epoch_size: 30" on the website UI interface.
|
||||
# (optional)Add "checkpoint_url='s3://dir_to_your_pretrained/'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare model code
|
||||
# (3) Upload or copy your pretrained model to S3 bucket if you want to finetune.
|
||||
# (4) Perform a or b. (suggested option a)
|
||||
# a. First, zip MindRecord dataset to one zip file.
|
||||
# Second, upload your zip dataset to S3 bucket.
|
||||
# b. Upload the original dataset to S3 bucket.
|
||||
# (Data set conversion occurs during training process and costs a lot of time. it happens every time you train.)
|
||||
# (5) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (6) Set the startup file to "train.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
#
|
||||
# Train 1p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "dataset_path='/cache/data'" on default_config.yaml file.
|
||||
# Set "epoch_size: 30" on default_config.yaml file.
|
||||
# (optional)Set "checkpoint_url='s3://dir_to_your_pretrained/'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "dataset_path='/cache/data'" on the website UI interface.
|
||||
# Add "epoch_size: 30" on the website UI interface.
|
||||
# (optional)Add "checkpoint_url='s3://dir_to_your_pretrained/'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare model code
|
||||
# (3) Upload or copy your pretrained model to S3 bucket if you want to finetune.
|
||||
# (4) Perform a or b. (suggested option a)
|
||||
# a. zip MindRecord dataset to one zip file.
|
||||
# Second, upload your zip dataset to S3 bucket.
|
||||
# b. Upload the original dataset to S3 bucket.
|
||||
# (Data set conversion occurs during training process and costs a lot of time. it happens every time you train.)
|
||||
# (5) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (6) Set the startup file to "train.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
#
|
||||
# Eval 1p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_your_trained_model/'" on base_config.yaml file.
|
||||
# Set "checkpoint='./alexnet/alexnet_trained.ckpt'" on default_config.yaml file.
|
||||
# Set "dataset_path='/cache/data'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "checkpoint_url='s3://dir_to_your_trained_model/'" on the website UI interface.
|
||||
# Add "checkpoint='./alexnet/alexnet_trained.ckpt'" on the website UI interface.
|
||||
# Add "dataset_path='/cache/data'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare model code
|
||||
# (3) Upload or copy your trained model to S3 bucket.
|
||||
# (4) Perform a or b. (suggested option a)
|
||||
# a. First, zip MindRecord dataset to one zip file.
|
||||
# Second, upload your zip dataset to S3 bucket.
|
||||
# b. Upload the original dataset to S3 bucket.
|
||||
# (Data set conversion occurs during training process and costs a lot of time. it happens every time you train.)
|
||||
# (5) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (6) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
```
|
||||
|
||||
- Export on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start evaluating as follows)
|
||||
|
||||
1. Export s8 multiscale and flip with voc val dataset on modelarts, evaluating steps are as follows:
|
||||
|
||||
```python
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||
# Set "file_name='alexnet'" on base_config.yaml file.
|
||||
# Set "file_format='AIR'" on base_config.yaml file.
|
||||
# Set "checkpoint_url='/The path of checkpoint in S3/'" on beta_config.yaml file.
|
||||
# Set "ckpt_file='/cache/checkpoint_path/model.ckpt'" on base_config.yaml file.
|
||||
# Set other parameters on base_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "file_name='alexnet'" on the website UI interface.
|
||||
# Add "file_format='AIR'" on the website UI interface.
|
||||
# Add "checkpoint_url='/The path of checkpoint in S3/'" on the website UI interface.
|
||||
# Add "ckpt_file='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Upload or copy your trained model to S3 bucket.
|
||||
# (3) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (4) Set the startup file to "export.py" on the website UI interface.
|
||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (6) Create your job.
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
||||
```bash
|
||||
├── cv
|
||||
├── alexnet
|
||||
├── README.md // descriptions about alexnet
|
||||
├── requirements.txt // package needed
|
||||
├── scripts
|
||||
│ ├──run_standalone_train_gpu.sh // train in gpu
|
||||
│ ├──run_standalone_train_ascend.sh // train in ascend
|
||||
│ ├──run_standalone_eval_gpu.sh // evaluate in gpu
|
||||
│ ├──run_standalone_eval_ascend.sh // evaluate in ascend
|
||||
├── src
|
||||
│ ├──dataset.py // creating dataset
|
||||
│ ├──alexnet.py // alexnet architecture
|
||||
│ └──model_utils
|
||||
│ ├──config.py // Processing configuration parameters
|
||||
│ ├──device_adapter.py // Get cloud ID
|
||||
│ ├──local_adapter.py // Get local ID
|
||||
│ └──moxing_adapter.py // Parameter processing
|
||||
├── default_config.yaml // Training parameter profile(cifar10)
|
||||
├── config_imagenet.yaml // Training parameter profile(imagenet)
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
```
|
||||
|
||||
### [Script Parameters](#contents)
|
||||
|
||||
```python
|
||||
Major parameters in train.py and config.py as follows:
|
||||
|
||||
--data_path: The absolute full path to the train and evaluation datasets.
|
||||
--epoch_size: Total training epochs.
|
||||
--batch_size: Training batch size.
|
||||
--image_height: Image height used as input to the model.
|
||||
--image_width: Image width used as input the model.
|
||||
--device_target: Device where the code will be implemented. Optional values are "Ascend", "GPU".
|
||||
--checkpoint_path: The absolute full path to the checkpoint file saved after training.
|
||||
--data_path: Path where the dataset is saved
|
||||
```
|
||||
|
||||
### [Training Process](#contents)
|
||||
|
||||
#### Training
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
```bash
|
||||
python train.py --config_path default_config.yaml --data_path cifar-10-batches-bin --ckpt_path ckpt > log 2>&1 &
|
||||
# or enter script dir, and run the script
|
||||
bash run_standalone_train_ascend.sh /home/DataSet/Cifar10/cifar-10-batches-bin/ /home/model/alexnet/ckpt/
|
||||
```
|
||||
|
||||
After training, the loss value will be achieved as follows:
|
||||
|
||||
```bash
|
||||
# grep "loss is " log
|
||||
epoch: 1 step: 1, loss is 2.2791853
|
||||
...
|
||||
epoch: 1 step: 1536, loss is 1.9366643
|
||||
epoch: 1 step: 1537, loss is 1.6983616
|
||||
epoch: 1 step: 1538, loss is 1.0221305
|
||||
...
|
||||
```
|
||||
|
||||
The model checkpoint will be saved in the current directory.
|
||||
|
||||
- running on GPU
|
||||
|
||||
```bash
|
||||
python train.py --config_path default_config.yaml --device_target "GPU" --data_path cifar-10-batches-bin --ckpt_path ckpt > log 2>&1 &
|
||||
# or enter script dir, and run the script
|
||||
bash run_standalone_train_for_gpu.sh cifar-10-batches-bin ckpt
|
||||
```
|
||||
|
||||
After training, the loss value will be achieved as follows:
|
||||
|
||||
```bash
|
||||
# grep "loss is " log
|
||||
epoch: 1 step: 1, loss is 2.3125906
|
||||
...
|
||||
epoch: 30 step: 1560, loss is 0.6687547
|
||||
epoch: 30 step: 1561, loss is 0.20055409
|
||||
epoch: 30 step: 1561, loss is 0.103845775
|
||||
```
|
||||
|
||||
### [Evaluation Process](#contents)
|
||||
|
||||
#### Evaluation
|
||||
|
||||
Before running the command below, please check the checkpoint path used for evaluation.
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```bash
|
||||
python eval.py --config_path default_config.yaml --data_path cifar-10-verify-bin --ckpt_path ckpt/checkpoint_alexnet-1_1562.ckpt > eval_log.txt 2>&1 &
|
||||
# or enter script dir, and run the script
|
||||
bash run_standalone_eval_ascend.sh cifar-10-verify-bin ckpt/checkpoint_alexnet-1_1562.ckpt
|
||||
```
|
||||
|
||||
You can view the results through the file "eval_log". The accuracy of the test dataset will be as follows:
|
||||
|
||||
```bash
|
||||
# grep "Accuracy: " eval_log
|
||||
'Accuracy': 0.8832
|
||||
```
|
||||
|
||||
- running on GPU
|
||||
|
||||
```bash
|
||||
python eval.py --config_path default_config.yaml --device_target "GPU" --data_path cifar-10-verify-bin --ckpt_path ckpt/checkpoint_alexnet-30_1562.ckpt > eval_log 2>&1 &
|
||||
# or enter script dir, and run the script
|
||||
bash run_standalone_eval_for_gpu.sh cifar-10-verify-bin ckpt/checkpoint_alexnet-30_1562.ckpt
|
||||
```
|
||||
|
||||
You can view the results through the file "eval_log". The accuracy of the test dataset will be as follows:
|
||||
|
||||
```bash
|
||||
# grep "Accuracy: " eval_log
|
||||
'Accuracy': 0.88512
|
||||
```
|
||||
|
||||
## [Inference Process](#contents)
|
||||
|
||||
### [Export MindIR](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --config_path [CONFIG_PATH] --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
The ckpt_file parameter is required,
|
||||
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
|
||||
### [Infer on Ascend310](#contents)
|
||||
|
||||
Before performing inference, the mindir file must be exported by `export.py` script. We only provide an example of inference using MINDIR model.
|
||||
Current batch_Size for imagenet2012 dataset can only be set to 1.
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATASET_NAME] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
```
|
||||
|
||||
- `MINDIR_PATH` specifies path of used "MINDIR" OR "AIR" model.
|
||||
- `DATASET_NAME` specifies datasets used to infer. value can be chosen between 'cifar10' and 'imagenet2012', defaulted is 'cifar10'
|
||||
- `DATASET_PATH` specifies path of cifar10 datasets
|
||||
- `NEED_PREPROCESS` means weather need preprocess or not, it's value is 'y' or 'n', if you choose y, the cifar10 dataset will be processed in bin format, the imagenet2012 dataset will generate label json file.
|
||||
- `DEVICE_ID` is optional, default value is 0.
|
||||
|
||||
### [Result](#contents)
|
||||
|
||||
Inference result is saved in current path, you can find result like this in acc.log file.
|
||||
|
||||
```bash
|
||||
'acc': 0.88772
|
||||
```
|
||||
|
||||
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```bash
|
||||
# Train 8p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "distribute=True" on default_config.yaml file.
|
||||
# Set "data_path='/cache/data'" on default_config.yaml file.
|
||||
# Set "ckpt_path='/cache/train'" on default_config.yaml file.
|
||||
# (optional)Set "checkpoint_url='s3://dir_to_your_pretrained/'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "distribute=True" on the website UI interface.
|
||||
# Add "data_path=/cache/data" on the website UI interface.
|
||||
# Add "ckpt_path=/cache/train" on the website UI interface.
|
||||
# (optional)Add "checkpoint_url='s3://dir_to_your_pretrained/'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare model code
|
||||
# (3) Upload or copy your pretrained model to S3 bucket if you want to finetune.
|
||||
# (4) Upload the original cifar10 dataset to S3 bucket.
|
||||
# (5) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (6) Set the startup file to "train.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
#
|
||||
# Train 1p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "data_path='/cache/data'" on default_config.yaml file.
|
||||
# Set "ckpt_path='/cache/train'" on default_config.yaml file.
|
||||
# (optional)Set "checkpoint_url='s3://dir_to_your_pretrained/'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "data_path=/cache/data" on the website UI interface.
|
||||
# Add "ckpt_path=/cache/train" on the website UI interface.
|
||||
# (optional)Add "checkpoint_url='s3://dir_to_your_pretrained/'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare model code
|
||||
# (3) Upload or copy your pretrained model to S3 bucket if you want to finetune.
|
||||
# (4) Upload the original cifar10 dataset to S3 bucket.
|
||||
# (5) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (6) Set the startup file to "train.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
#
|
||||
# Eval 1p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "data_path='/cache/data'" on default_config.yaml file.
|
||||
# Set "ckpt_file='/cache/train/checkpoint_alexnet-30_1562.ckpt'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "data_path=/cache/data" on the website UI interface.
|
||||
# Add "ckpt_file=/cache/train/checkpoint_alexnet-30_1562.ckpt" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare model code
|
||||
# (3) Upload or copy your trained model to S3 bucket.
|
||||
# (4) Upload the original cifar10 dataset to S3 bucket.
|
||||
# (5) Set the code directory to "/path/alexnet" on the website UI interface.
|
||||
# (6) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
```
|
||||
|
||||
## [Model Description](#contents)
|
||||
|
||||
### [Performance](#contents)
|
||||
|
||||
#### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ------------------------------------------------------------| -------------------------------------------------|
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G |
|
||||
| uploaded Date | 07/05/2021 (month/day/year) | 17/09/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.2.1 | 1.2.1 |
|
||||
| Dataset | CIFAR-10 | CIFAR-10 |
|
||||
| Training Parameters | epoch=30, steps=1562, batch_size = 32, lr=0.002 | epoch=30, steps=1562, batch_size = 32, lr=0.002 |
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 0.08 | 0.01 |
|
||||
| Speed | 7.3 ms/step | 16.8 ms/step |
|
||||
| Total time | 6 mins | 14 mins |
|
||||
| Checkpoint for Fine tuning | 445M (.ckpt file) | 445M (.ckpt file) |
|
||||
| Scripts | [AlexNet Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet) | [AlexNet Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet) |
|
||||
|
||||
## [Description of Random Situation](#contents)
|
||||
|
||||
In dataset.py, we set the seed inside ```create_dataset``` function.
|
||||
|
||||
## [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -1,359 +0,0 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [AlexNet描述](#alexnet描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [训练](#训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [推理过程](#推理过程)
|
||||
- [导出MindIR](#导出mindir)
|
||||
- [在Ascend310执行推理](#在ascend310执行推理)
|
||||
- [结果](#结果)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## AlexNet描述
|
||||
|
||||
AlexNet是2012年提出的最有影响力的神经网络之一。该网络在ImageNet数据集识别方面取得了显着的成功。
|
||||
|
||||
[论文](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-concumulational-neural-networks.pdf): Krizhevsky A, Sutskever I, Hinton G E. ImageNet Classification with Deep ConvolutionalNeural Networks. *Advances In Neural Information Processing Systems*. 2012.
|
||||
|
||||
## 模型架构
|
||||
|
||||
AlexNet由5个卷积层和3个全连接层组成。多个卷积核用于提取图像中有趣的特征,从而得到更精确的分类。
|
||||
|
||||
## 数据集
|
||||
|
||||
使用的数据集:[CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
|
||||
|
||||
- 数据集大小:175M,共10个类、60,000个32*32彩色图像
|
||||
- 训练集:146M,50,000个图像
|
||||
- 测试集:29.3M,10,000个图像
|
||||
- 数据格式:二进制文件
|
||||
- 注意:数据在dataset.py中处理。
|
||||
- 下载数据集。目录结构如下:
|
||||
|
||||
```bash
|
||||
├─cifar-10-batches-bin
|
||||
│
|
||||
└─cifar-10-verify-bin
|
||||
```
|
||||
|
||||
## 环境要求
|
||||
|
||||
- 硬件(Ascend/GPU)
|
||||
- 准备Ascend或GPU处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
||||
## 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
```python
|
||||
# 进入脚本目录,训练AlexNet
|
||||
bash run_standalone_train_ascend.sh [DATA_PATH] [CKPT_SAVE_PATH]
|
||||
# example: bash run_standalone_train_ascend.sh /home/DataSet/Cifar10/cifar-10-batches-bin/ /home/model/alexnet/ckpt/
|
||||
|
||||
# 分布式训练AlexNet
|
||||
|
||||
# 进入脚本目录,评估AlexNet
|
||||
bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_NAME]
|
||||
# example: bash run_standalone_eval_ascend.sh /home/DataSet/cifar10/cifar-10-verify-bin /home/model/cv/alxnet/ckpt/checkpoint_alexnet-1_1562.ckpt
|
||||
```
|
||||
|
||||
- 在 ModelArts 进行训练 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||
|
||||
```bash
|
||||
# 在 ModelArts 上使用8卡训练
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "distribute=True"
|
||||
# 在 default_config.yaml 文件中设置 "data_path='/cache/data'"
|
||||
# 在 default_config.yaml 文件中设置 "ckpt_path='/cache/train'"
|
||||
# (可选)在 default_config.yaml 文件中设置 "checkpoint_url='s3://dir_to_your_pretrained/'"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "distribute=True"
|
||||
# 在网页上设置 "data_path='/cache/data'"
|
||||
# 在网页上设置 "ckpt_path='/cache/train'"
|
||||
# (可选)在网页上设置 "checkpoint_url='s3://dir_to_your_pretrained/'"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 准备模型代码
|
||||
# (3) 如果选择微调您的模型,请上传你的预训练模型到 S3 桶上
|
||||
# (4) 上传原始 cifar10 数据集到 S3 桶上
|
||||
# (5) 在网页上设置你的代码路径为 "/path/alexnet"
|
||||
# (6) 在网页上设置启动文件为 "train.py"
|
||||
# (7) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (8) 创建训练作业
|
||||
#
|
||||
# 在 ModelArts 上使用单卡训练
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "data_path='/cache/data'"
|
||||
# 在 default_config.yaml 文件中设置 "ckpt_path='/cache/train'"
|
||||
# (可选)在 default_config.yaml 文件中设置 "checkpoint_url='s3://dir_to_your_pretrained/'"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "data_path='/cache/data'"
|
||||
# 在网页上设置 "ckpt_path='/cache/train'"
|
||||
# (可选)在网页上设置 "checkpoint_url='s3://dir_to_your_pretrained/'"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 准备模型代码
|
||||
# (3) 如果选择微调您的模型,上传你的预训练模型到 S3 桶上
|
||||
# (4) 上传原始 cifar10 数据集到 S3 桶上
|
||||
# (5) 在网页上设置你的代码路径为 "/path/alexnet"
|
||||
# (6) 在网页上设置启动文件为 "train.py"
|
||||
# (7) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (8) 创建训练作业
|
||||
#
|
||||
# 在 ModelArts 上使用单卡验证
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "data_path='/cache/data'"
|
||||
# 在 default_config.yaml 文件中设置 "ckpt_file='/cache/train/checkpoint_alexnet-30_1562.ckpt'"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "data_path='/cache/data'"
|
||||
# 在网页上设置 "ckpt_file='/cache/train/checkpoint_alexnet-30_1562.ckpt'"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 准备模型代码
|
||||
# (3) 上传你训练好的模型到 S3 桶上
|
||||
# (4) 上传原始 cifar10 数据集到 S3 桶上
|
||||
# (5) 在网页上设置你的代码路径为 "/path/alexnet"
|
||||
# (6) 在网页上设置启动文件为 "train.py"
|
||||
# (7) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (8) 创建训练作业
|
||||
```
|
||||
|
||||
- 在 ModelArts 进行导出 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||
|
||||
1. 使用voc val数据集评估多尺度和翻转s8。评估步骤如下:
|
||||
|
||||
```python
|
||||
# (1) 执行 a 或者 b.
|
||||
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 base_config.yaml 文件中设置 "file_name='alexnet'"
|
||||
# 在 base_config.yaml 文件中设置 "file_format='AIR'"
|
||||
# 在 base_config.yaml 文件中设置 "checkpoint_url='/The path of checkpoint in S3/'"
|
||||
# 在 base_config.yaml 文件中设置 "ckpt_file='/cache/checkpoint_path/model.ckpt'"
|
||||
# 在 base_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "file_name='alexnet'"
|
||||
# 在网页上设置 "file_format='AIR'"
|
||||
# 在网页上设置 "checkpoint_url='/The path of checkpoint in S3/'"
|
||||
# 在网页上设置 "ckpt_file='/cache/checkpoint_path/model.ckpt'"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 上传你的预训练模型到 S3 桶上
|
||||
# (3) 在网页上设置你的代码路径为 "/path/alexnet"
|
||||
# (4) 在网页上设置启动文件为 "export.py"
|
||||
# (5) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (6) 创建训练作业
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
### 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── cv
|
||||
├── alexnet
|
||||
├── README.md // AlexNet相关说明
|
||||
├── requirements.txt // 所需要的包
|
||||
├── scripts
|
||||
│ ├──run_standalone_train_gpu.sh // 在GPU中训练
|
||||
│ ├──run_standalone_train_ascend.sh // 在Ascend中训练
|
||||
│ ├──run_standalone_eval_gpu.sh // 在GPU中评估
|
||||
│ ├──run_standalone_eval_ascend.sh // 在Ascend中评估
|
||||
├── src
|
||||
│ ├──dataset.py // 创建数据集
|
||||
│ ├──alexnet.py // AlexNet架构
|
||||
| └──model_utils
|
||||
| ├──config.py // 训练配置
|
||||
| ├──device_adapter.py // 获取云上id
|
||||
| ├──local_adapter.py // 获取本地id
|
||||
| └──moxing_adapter.py // 参数处理
|
||||
├── default_config.yaml // 训练参数配置文件
|
||||
├── config_imagenet.yaml // 训练参数配置文件
|
||||
├── train.py // 训练脚本
|
||||
├── eval.py // 评估脚本
|
||||
```
|
||||
|
||||
### 脚本参数
|
||||
|
||||
```python
|
||||
train.py和config.py中主要参数如下:
|
||||
|
||||
--data_path:到训练和评估数据集的绝对完整路径。
|
||||
--epoch_size:总训练轮次。
|
||||
--batch_size:训练批次大小。
|
||||
--image_height:图像高度作为模型输入。
|
||||
--image_width:图像宽度作为模型输入。
|
||||
--device_target:实现代码的设备。可选值为"Ascend"、"GPU"。
|
||||
--checkpoint_path:训练后保存的检查点文件的绝对完整路径。
|
||||
--data_path:数据集所在路径
|
||||
```
|
||||
|
||||
### 训练过程
|
||||
|
||||
#### 训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
python train.py --config_path default_config.yaml --data_path cifar-10-batches-bin --ckpt_path ckpt > log 2>&1 &
|
||||
# 或进入脚本目录,执行脚本
|
||||
bash run_standalone_train_ascend.sh /home/DataSet/Cifar10/cifar-10-batches-bin/ /home/model/alexnet/ckpt/
|
||||
```
|
||||
|
||||
经过训练后,损失值如下:
|
||||
|
||||
```bash
|
||||
# grep "loss is " log
|
||||
epoch: 1 step: 1, loss is 2.2791853
|
||||
...
|
||||
epoch: 1 step: 1536, loss is 1.9366643
|
||||
epoch: 1 step: 1537, loss is 1.6983616
|
||||
epoch: 1 step: 1538, loss is 1.0221305
|
||||
...
|
||||
```
|
||||
|
||||
模型检查点保存在当前目录下。
|
||||
|
||||
- GPU环境运行
|
||||
|
||||
```bash
|
||||
python train.py --config_path default_config.yaml --device_target "GPU" --data_path cifar-10-batches-bin --ckpt_path ckpt > log 2>&1 &
|
||||
# 或进入脚本目录,执行脚本
|
||||
bash run_standalone_train_for_gpu.sh cifar-10-batches-bin ckpt
|
||||
```
|
||||
|
||||
经过训练后,损失值如下:
|
||||
|
||||
```bash
|
||||
# grep "loss is " log
|
||||
epoch: 1 step: 1, loss is 2.3125906
|
||||
...
|
||||
epoch: 30 step: 1560, loss is 0.6687547
|
||||
epoch: 30 step: 1561, loss is 0.20055409
|
||||
epoch: 30 step: 1561, loss is 0.103845775
|
||||
```
|
||||
|
||||
### 评估过程
|
||||
|
||||
#### 评估
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
python eval.py --config_path default_config.yaml --data_path cifar-10-verify-bin --ckpt_path ckpt/checkpoint_alexnet-1_1562.ckpt > eval_log.txt 2>&1 &
|
||||
#或进入脚本目录,执行脚本
|
||||
bash run_standalone_eval_ascend.sh /home/DataSet/cifar10/cifar-10-verify-bin /home/model/cv/alxnet/ckpt/checkpoint_alexnet-1_1562.ckpt
|
||||
```
|
||||
|
||||
可通过"eval_log”文件查看结果。测试数据集的准确率如下:
|
||||
|
||||
```bash
|
||||
# grep "Accuracy: " eval_log
|
||||
'Accuracy': 0.8832
|
||||
```
|
||||
|
||||
- GPU环境运行
|
||||
|
||||
```bash
|
||||
python eval.py --config_path default_config.yaml --device_target "GPU" --data_path cifar-10-verify-bin --ckpt_path ckpt/checkpoint_alexnet-30_1562.ckpt > eval_log 2>&1 &
|
||||
#或进入脚本目录,执行脚本
|
||||
bash run_standalone_eval_for_gpu.sh cifar-10-verify-bin ckpt/checkpoint_alexnet-30_1562.ckpt
|
||||
```
|
||||
|
||||
可通过"eval_log”文件查看结果。测试数据集的准确率如下:
|
||||
|
||||
```bash
|
||||
# grep "Accuracy: " eval_log
|
||||
'Accuracy': 0.88512
|
||||
```
|
||||
|
||||
## 推理过程
|
||||
|
||||
### 导出MindIR
|
||||
|
||||
```shell
|
||||
python export.py --config_path [CONFIG_PATH] --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
参数ckpt_file为必填项,
|
||||
`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中选择。
|
||||
|
||||
### 在Ascend310执行推理
|
||||
|
||||
在执行推理前,mindir文件必须通过`export.py`脚本导出。以下展示了使用minir模型执行推理的示例。
|
||||
目前imagenet2012数据集仅支持batch_Size为1的推理。
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATASET_NAME] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
```
|
||||
|
||||
- `MINDIR_PATH` mindir文件路径
|
||||
- `DATASET_NAME` 使用的推理数据集名称,默认为`cifar10`,可在`cifar10`或者`imagenet2012`中选择
|
||||
- `DATASET_PATH` 推理数据集路径
|
||||
- `NEED_PREPROCESS` 表示数据集是否需要预处理,可在`y`或者`n`中选择,如果选择`y`,cifar10数据集将被处理为bin格式。
|
||||
- `DEVICE_ID` 可选,默认值为0。
|
||||
|
||||
### 结果
|
||||
|
||||
推理结果保存在脚本执行的当前路径,你可以在acc.log中看到以下精度计算结果。
|
||||
|
||||
```bash
|
||||
'acc': 0.88772
|
||||
```
|
||||
|
||||
## 模型描述
|
||||
|
||||
### 性能
|
||||
|
||||
#### 评估性能
|
||||
|
||||
| 参数 | Ascend | GPU |
|
||||
| -------------------------- | ------------------------------------------------------------| -------------------------------------------------|
|
||||
| 资源 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 | NV SMX2 V100-32G |
|
||||
| 上传日期 | 2021-07-05 | 2020-09-17 |
|
||||
| MindSpore版本 | 1.2.1 | 1.2.1 |
|
||||
| 数据集 | CIFAR-10 | CIFAR-10 |
|
||||
| 训练参数 | epoch=30, step=1562, batch_size=32, lr=0.002 | epoch=30, step=1562, batch_size=32, lr=0.002 |
|
||||
| 优化器 | 动量 | 动量 |
|
||||
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
|
||||
| 输出 | 概率 | 概率 | 概率 |
|
||||
| 损失 | 0.0016 | 0.01 |
|
||||
| 速度 | 7毫秒/步 | 16.8毫秒/步 |
|
||||
| 总时间 | 6分钟 | 14分钟|
|
||||
| 微调检查点 | 445M (.ckpt文件) | 445M (.ckpt文件) |
|
||||
| 脚本 | [AlexNet脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet) | [AlexNet脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet) |
|
||||
|
||||
## 随机情况说明
|
||||
|
||||
dataset.py中设置了“create_dataset”函数内的种子。
|
||||
|
||||
## ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -1,14 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
option(MINDSPORE_PATH "mindspore install path" "")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -1,29 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ -d out ]; then
|
||||
rm -rf out
|
||||
fi
|
||||
|
||||
mkdir out
|
||||
cd out || exit
|
||||
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
|
@ -1,35 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
std::string RealPath(std::string_view path);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||
std::vector<std::string> GetAllFiles(std::string dir_name);
|
||||
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
|
||||
|
||||
#endif
|
|
@ -1,189 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/dataset/vision_ascend.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "include/dataset/transforms.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "inc/utils.h"
|
||||
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Status;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::dataset::vision::Decode;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
using mindspore::dataset::vision::CenterCrop;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
|
||||
|
||||
DEFINE_string(mindir_path, "", "mindir path");
|
||||
DEFINE_string(dataset_name, "cifar10", "['cifar10', 'imagenet2012']");
|
||||
DEFINE_string(input0_path, ".", "input0 path");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
|
||||
int load_model(Model *model, std::vector<MSTensor> *model_inputs, std::string mindir_path, int device_id) {
|
||||
if (RealPath(mindir_path).empty()) {
|
||||
std::cout << "Invalid mindir" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(device_id);
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
mindspore::Graph graph;
|
||||
Serialization::Load(mindir_path, ModelType::kMindIR, &graph);
|
||||
|
||||
Status ret = model->Build(GraphCell(graph), context);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Build failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
*model_inputs = model->GetInputs();
|
||||
if (model_inputs->empty()) {
|
||||
std::cout << "Invalid model, inputs is empty." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
|
||||
Model model;
|
||||
std::vector<MSTensor> model_inputs;
|
||||
load_model(&model, &model_inputs, FLAGS_mindir_path, FLAGS_device_id);
|
||||
|
||||
std::map<double, double> costTime_map;
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
double startTimeMs;
|
||||
double endTimeMs;
|
||||
|
||||
if (FLAGS_dataset_name == "cifar10") {
|
||||
auto input0_files = GetAllFiles(FLAGS_input0_path);
|
||||
if (input0_files.empty()) {
|
||||
std::cout << "ERROR: no input data." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
size_t size = input0_files.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
std::cout << "Start predict input files:" << input0_files[i] <<std::endl;
|
||||
auto input0 = ReadFileToTensor(input0_files[i]);
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
input0.Data().get(), input0.DataSize());
|
||||
|
||||
gettimeofday(&start, nullptr);
|
||||
Status ret = model.Predict(inputs, &outputs);
|
||||
gettimeofday(&end, nullptr);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||
int rst = WriteResult(input0_files[i], outputs);
|
||||
if (rst != 0) {
|
||||
std::cout << "write result failed." << std::endl;
|
||||
return rst;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto input0_files = GetAllInputData(FLAGS_input0_path);
|
||||
if (input0_files.empty()) {
|
||||
std::cout << "ERROR: no input data." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
size_t size = input0_files.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t j = 0; j < input0_files[i].size(); ++j) {
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
std::cout << "Start predict input files:" << input0_files[i][j] <<std::endl;
|
||||
auto decode = Decode();
|
||||
auto resize = Resize({256, 256});
|
||||
auto centercrop = CenterCrop({224, 224});
|
||||
auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375});
|
||||
auto hwc2chw = HWC2CHW();
|
||||
|
||||
Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});
|
||||
auto imgDvpp = std::make_shared<MSTensor>();
|
||||
SingleOp(ReadFileToTensor(input0_files[i][j]), imgDvpp.get());
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
imgDvpp->Data().get(), imgDvpp->DataSize());
|
||||
gettimeofday(&start, nullptr);
|
||||
Status ret = model.Predict(inputs, &outputs);
|
||||
gettimeofday(&end, nullptr);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Predict " << input0_files[i][j] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||
int rst = WriteResult(input0_files[i][j], outputs);
|
||||
if (rst != 0) {
|
||||
std::cout << "write result failed." << std::endl;
|
||||
return rst;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
double average = 0.0;
|
||||
int inferCount = 0;
|
||||
|
||||
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||
double diff = 0.0;
|
||||
diff = iter->second - iter->first;
|
||||
average += diff;
|
||||
inferCount++;
|
||||
}
|
||||
average = average / inferCount;
|
||||
std::stringstream timeCost;
|
||||
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
|
||||
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
|
||||
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
|
||||
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||
fileStream << timeCost.str();
|
||||
fileStream.close();
|
||||
costTime_map.clear();
|
||||
return 0;
|
||||
}
|
|
@ -1,197 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "inc/utils.h"
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
|
||||
std::vector<std::vector<std::string>> ret;
|
||||
|
||||
DIR *dir = OpenDir(dir_name);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
struct dirent *filename;
|
||||
/* read all the files in the dir ~ */
|
||||
std::vector<std::string> sub_dirs;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string d_name = std::string(filename->d_name);
|
||||
// get rid of "." and ".."
|
||||
if (d_name == "." || d_name == ".." || d_name.empty()) {
|
||||
continue;
|
||||
}
|
||||
std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
|
||||
struct stat s;
|
||||
lstat(dir_path.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
sub_dirs.emplace_back(dir_path);
|
||||
}
|
||||
std::sort(sub_dirs.begin(), sub_dirs.end());
|
||||
|
||||
(void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
|
||||
[](const std::string &d) { return GetAllFiles(d); });
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string dir_name) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dir_name);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::string> res;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string d_name = std::string(filename->d_name);
|
||||
if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
|
||||
continue;
|
||||
}
|
||||
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
|
||||
}
|
||||
std::sort(res.begin(), res.end());
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dirName);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> res;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||
continue;
|
||||
}
|
||||
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
}
|
||||
std::sort(res.begin(), res.end());
|
||||
for (auto &f : res) {
|
||||
std::cout << "image file: " << f << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||
std::string homePath = "./result_Files";
|
||||
const int INVALID_POINTER = -1;
|
||||
const int ERROR = -2;
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
size_t outputSize;
|
||||
std::shared_ptr<const void> netOutput;
|
||||
netOutput = outputs[i].Data();
|
||||
outputSize = outputs[i].DataSize();
|
||||
int pos = imageFile.rfind('/');
|
||||
std::string fileName(imageFile, pos + 1);
|
||||
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||
std::string outFileName = homePath + "/" + fileName;
|
||||
FILE *outputFile = fopen(outFileName.c_str(), "wb");
|
||||
if (outputFile == nullptr) {
|
||||
std::cout << "open result file " << outFileName << " failed" << std::endl;
|
||||
return INVALID_POINTER;
|
||||
}
|
||||
size_t size = fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
|
||||
if (size != outputSize) {
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
std::cout << "write result file " << outFileName << " failed, write size[" << size <<
|
||||
"] is smaller than output size[" << outputSize << "], maybe the disk is full." << std::endl;
|
||||
return ERROR;
|
||||
}
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << "open failed" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
|
||||
if (dirName.empty()) {
|
||||
std::cout << " dirName is null ! " << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = RealPath(dirName);
|
||||
struct stat s;
|
||||
lstat(realPath.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
DIR *dir;
|
||||
dir = opendir(realPath.c_str());
|
||||
if (dir == nullptr) {
|
||||
std::cout << "Can not open dir " << dirName << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||
return dir;
|
||||
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
|
||||
char realPathMem[PATH_MAX] = {0};
|
||||
char *realPathRet = nullptr;
|
||||
realPathRet = realpath(path.data(), realPathMem);
|
||||
if (realPathRet == nullptr) {
|
||||
std::cout << "File: " << path << " is not exist.";
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string realPath(realPathMem);
|
||||
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||
return realPath;
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
checkpoint_path: './checkpoint/'
|
||||
checkpoint_file: './checkpoint/checkpoint_alexnet-30_1562.ckpt'
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
ckpt_path: "/cache/data"
|
||||
ckpt_file: "/cache/data/checkpoint_alexnet-30_1562.ckpt"
|
||||
# ==============================================================================
|
||||
# Training options
|
||||
num_classes: 1000
|
||||
learning_rate: 0.13
|
||||
momentum: 0.9
|
||||
epoch_size: 150
|
||||
batch_size: 256
|
||||
buffer_size: None
|
||||
image_height: 224
|
||||
image_width: 224
|
||||
save_checkpoint_steps: 625
|
||||
keep_checkpoint_max: 10
|
||||
air_name: 'alexnet.air'
|
||||
|
||||
weight_decay: 0.0001
|
||||
loss_scale: 1024
|
||||
is_dynamic_loss_scale: 0
|
||||
|
||||
# Model Description
|
||||
model_name: alexnet
|
||||
file_name: 'alexnet'
|
||||
file_format: 'AIR'
|
||||
|
||||
dataset_name: 'imagenet'
|
||||
sink_size: -1
|
||||
dataset_sink_mode: True
|
||||
device_id: 0
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 2
|
||||
lr: 0.01
|
||||
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
device_target: 'Target device type'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
|
@ -1,56 +0,0 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
checkpoint_path: './checkpoint/'
|
||||
checkpoint_file: './checkpoint/checkpoint_alexnet-30_1562.ckpt'
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
ckpt_path: "/cache/train"
|
||||
ckpt_file: "/cache/train/checkpoint_alexnet-30_1562.ckpt"
|
||||
# ==============================================================================
|
||||
# Training options
|
||||
epoch_size: 30
|
||||
keep_checkpoint_max: 10
|
||||
num_classes: 10
|
||||
learning_rate: 0.002
|
||||
momentum: 0.9
|
||||
batch_size: 32
|
||||
buffer_size: 1000
|
||||
image_height: 227
|
||||
image_width: 227
|
||||
save_checkpoint_steps: 1562
|
||||
air_name: 'alexnet.air'
|
||||
|
||||
dataset_name: 'cifar10'
|
||||
sink_size: -1
|
||||
dataset_sink_mode: True
|
||||
device_id: 0
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 2
|
||||
lr: 0.01
|
||||
|
||||
# Model Description
|
||||
model_name: alexnet
|
||||
file_name: 'alexnet'
|
||||
file_format: 'AIR'
|
||||
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
device_target: 'Target device type'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
######################## eval alexnet example ########################
|
||||
eval alexnet according to model file:
|
||||
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
|
||||
"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
|
||||
from src.alexnet import AlexNet
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.communication.management import init
|
||||
|
||||
|
||||
def modelarts_process():
|
||||
config.ckpt_path = config.ckpt_file
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_process)
|
||||
def eval_alexnet():
|
||||
print("============== Starting Testing ==============")
|
||||
device_num = get_device_num()
|
||||
if device_num > 1:
|
||||
# context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Davinci', save_graphs=False)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
init()
|
||||
elif config.device_target == "GPU":
|
||||
init()
|
||||
|
||||
if config.dataset_name == 'cifar10':
|
||||
network = AlexNet(config.num_classes, phase='test')
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
opt = nn.Momentum(network.trainable_params(), config.learning_rate, config.momentum)
|
||||
ds_eval = create_dataset_cifar10(config, config.data_path, config.batch_size, status="test", \
|
||||
target=config.device_target)
|
||||
param_dict = load_checkpoint(config.ckpt_path)
|
||||
print("load checkpoint from [{}].".format(config.ckpt_path))
|
||||
load_param_into_net(network, param_dict)
|
||||
network.set_train(False)
|
||||
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
elif config.dataset_name == 'imagenet':
|
||||
network = AlexNet(config.num_classes, phase='test')
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
ds_eval = create_dataset_imagenet(config, config.data_path, config.batch_size, training=False)
|
||||
param_dict = load_checkpoint(config.ckpt_path)
|
||||
print("load checkpoint from [{}].".format(config.ckpt_path))
|
||||
load_param_into_net(network, param_dict)
|
||||
network.set_train(False)
|
||||
model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
else:
|
||||
raise ValueError("Unsupported dataset.")
|
||||
|
||||
if ds_eval.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
result = model.eval(ds_eval, dataset_sink_mode=config.dataset_sink_mode)
|
||||
print("result : {}".format(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
eval_alexnet()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue