!12885 Add transfer training to unet and update readme.
From: @c_34 Reviewed-by: Signed-off-by:
commit
1fb56a2481
|
@ -0,0 +1,6 @@
|
|||
ARG FROM_IMAGE_NAME
|
||||
FROM ${FROM_IMAGE_NAME}
|
||||
|
||||
RUN apt install libgl1-mesa-glx -y
|
||||
COPY requirements.txt .
|
||||
RUN pip3.7 install -r requirements.txt
|
|
@ -1,26 +1,30 @@
|
|||
# Contents
|
||||
|
||||
- [Unet Description](#unet-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Contents](#contents)
|
||||
- [Unet Description](#unet-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [running on Ascend](#running-on-ascend)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [How to use](#how-to-use)
|
||||
- [Inference](#inference)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [How to use](#how-to-use)
|
||||
- [Inference](#inference)
|
||||
- [Running on Ascend 310](#running-on-ascend-310)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Transfer training](#transfer-training)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
## [Unet Description](#contents)
|
||||
|
||||
|
@ -28,24 +32,24 @@ Unet Medical model for 2D image segmentation. This implementation is as describe
|
|||
|
||||
[Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." *conditionally accepted at MICCAI 2015*. 2015.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
## [Model Architecture](#contents)
|
||||
|
||||
UNet proposes a U-shaped network structure that extracts and fuses high-level features while capturing both context and spatial location information. The structure consists of an encoder and a decoder. The encoder repeats a block of two 3x3 convolutions followed by a 2x2 max pooling, and the number of channels doubles after each downsampling. The decoder repeats a block of a 2x2 deconvolution, a concat layer, and two 3x3 convolutions; a final 1x1 convolution produces the output.
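
As a rough illustration of the channel doubling described above, the sketch below lists the channel widths along the encoder and decoder. The base width of 64 and the depth of 4 follow the original U-Net paper and are assumptions here, and `unet_channel_plan` is a hypothetical helper, not a function from this repository.

```python
def unet_channel_plan(base_channels=64, depth=4):
    """Return (encoder, decoder) channel widths for a U-shaped network.

    Assumption: the width doubles after each downsampling and halves
    after each upsampling, as in the original U-Net paper.
    """
    # One entry per resolution level, including the bottleneck.
    encoder = [base_channels * 2 ** level for level in range(depth + 1)]
    # The decoder walks back up, skipping the bottleneck itself.
    decoder = encoder[-2::-1]
    return encoder, decoder


encoder, decoder = unet_channel_plan()
print("encoder:", encoder)  # encoder: [64, 128, 256, 512, 1024]
print("decoder:", decoder)  # decoder: [512, 256, 128, 64]
```

The final 1x1 convolution then maps the last decoder width to the number of classes.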
|
||||
|
||||
# [Dataset](#contents)
|
||||
## [Dataset](#contents)
|
||||
|
||||
Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
|
||||
|
||||
- Description: The training and test datasets are two stacks of 30 sections from a serial section Transmission Electron Microscopy (ssTEM) data set of the Drosophila first instar larva ventral nerve cord (VNC). The microcube measures 2 x 2 x 1.5 microns approx., with a resolution of 4x4x50 nm/pixel.
|
||||
- License: You are free to use this data set for the purpose of generating or testing non-commercial image segmentation software. If any scientific publications derive from the usage of this data set, you must cite TrakEM2 and the following publication: Cardona A, Saalfeld S, Preibisch S, Schmid B, Cheng A, Pulokas J, Tomancak P, Hartenstein V. 2010. An Integrated Micro- and Macroarchitectural Analysis of the Drosophila Brain by Computer-Assisted Serial Section Electron Microscopy. PLoS Biol 8(10): e1000502. doi:10.1371/journal.pbio.1000502.
|
||||
- Dataset size:22.5M,
|
||||
- Dataset size:22.5M,
|
||||
- Train: 15M, 30 images (the training data consists of 2 multi-page TIF files, each containing 30 2D images; train-volume.tif holds the images and train-labels.tif holds the labels.)
|
||||
- Val: (the training data is randomly divided into 5 folds, and the model is evaluated by 5-fold cross-validation.)
|
||||
- Test: 7.5M, 30 images (the test data consists of 1 multi-page TIF file containing 30 2D images; test-volume.tif holds the images.)
|
||||
- Data format:binary files(TIF file)
|
||||
- Note:Data will be processed in src/data_loader.py
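
The 5-fold split mentioned in the list above can be sketched as follows. This is a minimal illustration using only the standard library, not the logic in src/data_loader.py; the function name `k_fold_indices`, its parameters, and the fixed seed are assumptions.

```python
import random


def k_fold_indices(num_samples, k=5, fold=0, seed=0):
    """Split range(num_samples) into (train, val) index lists for one fold."""
    if not 0 <= fold < k:
        raise ValueError("fold must be in [0, k)")
    indices = list(range(num_samples))
    # A fixed seed keeps the split reproducible across runs.
    random.Random(seed).shuffle(indices)
    # Every k-th shuffled index, starting at `fold`, goes to validation.
    val = sorted(indices[fold::k])
    val_set = set(val)
    train = sorted(i for i in indices if i not in val_set)
    return train, val


train_idx, val_idx = k_fold_indices(30, k=5, fold=0)
print(len(train_idx), len(val_idx))  # 24 6
```

Iterating `fold` from 0 to 4 uses each image for validation exactly once.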
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
|
@ -55,32 +59,50 @@ Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
|
|||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
## [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
- running on Ascend
|
||||
- Run on Ascend
|
||||
|
||||
```python
|
||||
# run training example
|
||||
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_train.sh [DATASET]
|
||||
```python
|
||||
# run training example
|
||||
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_train.sh [DATASET]
|
||||
|
||||
# run distributed training example
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
|
||||
# run distributed training example
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
|
||||
|
||||
# run evaluation example
|
||||
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
|
||||
```
|
||||
# run evaluation example
|
||||
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
- Run on docker
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
Build the Docker image (change the version to the one you actually use):
|
||||
|
||||
```text
|
||||
```shell
|
||||
# build docker
|
||||
docker build -t unet:20.1.0 . --build-arg FROM_IMAGE_NAME=ascend-mindspore-arm:20.1.0
|
||||
```
|
||||
|
||||
Create a container from the built image and start it:
|
||||
|
||||
```shell
|
||||
# start docker
|
||||
bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
|
||||
```
|
||||
|
||||
You can then run everything inside the container just as on Ascend.
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
||||
```shell
|
||||
├── model_zoo
|
||||
├── README.md // descriptions about all the models
|
||||
├── unet
|
||||
|
@ -102,7 +124,7 @@ After installing MindSpore via the official website, you can start training and
|
|||
├── eval.py // evaluation script
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
### [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
|
||||
|
@ -123,44 +145,44 @@ Parameters for both training and evaluation can be set in config.py
|
|||
'FixedLossScaleManager': 1024.0, # fix loss scale
|
||||
'resume': False, # whether to resume training from a pretrained model
|
||||
'resume_ckpt': './', # pretrained model path
|
||||
'transfer_training': False, # whether to do transfer training
|
||||
'filter_weight': ["final.weight"] # weight names to filter out during transfer training
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Training
|
||||
|
||||
- running on Ascend
|
||||
#### running on Ascend
|
||||
|
||||
```shell
|
||||
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_train.sh [DATASET]
|
||||
```
|
||||
```shell
|
||||
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_train.sh [DATASET]
|
||||
```
|
||||
|
||||
The python command above runs in the background; you can view the results in the file `train.log`.
|
||||
The python command above runs in the background; you can view the results in the file `train.log`.
|
||||
|
||||
After training, you'll get some checkpoint files under the script folder by default. The loss values are logged as follows:
|
||||
After training, you'll get some checkpoint files under the script folder by default. The loss values are logged as follows:
|
||||
|
||||
```shell
|
||||
```shell
|
||||
# grep "loss is " train.log
|
||||
step: 1, loss is 0.7011719, fps is 0.25025035060906264
|
||||
step: 2, loss is 0.69433594, fps is 56.77693756377044
|
||||
step: 3, loss is 0.69189453, fps is 57.3293877244179
|
||||
step: 4, loss is 0.6894531, fps is 57.840651522059716
|
||||
step: 5, loss is 0.6850586, fps is 57.89903776054361
|
||||
step: 6, loss is 0.6777344, fps is 58.08073627299014
|
||||
...
|
||||
step: 597, loss is 0.19030762, fps is 58.28088370287449
|
||||
step: 598, loss is 0.19958496, fps is 57.95493929352674
|
||||
step: 599, loss is 0.18371582, fps is 58.04039977720966
|
||||
step: 600, loss is 0.22070312, fps is 56.99692546024671
|
||||
```
|
||||
|
||||
# grep "loss is " train.log
|
||||
step: 1, loss is 0.7011719, fps is 0.25025035060906264
|
||||
step: 2, loss is 0.69433594, fps is 56.77693756377044
|
||||
step: 3, loss is 0.69189453, fps is 57.3293877244179
|
||||
step: 4, loss is 0.6894531, fps is 57.840651522059716
|
||||
step: 5, loss is 0.6850586, fps is 57.89903776054361
|
||||
step: 6, loss is 0.6777344, fps is 58.08073627299014
|
||||
...
|
||||
step: 597, loss is 0.19030762, fps is 58.28088370287449
|
||||
step: 598, loss is 0.19958496, fps is 57.95493929352674
|
||||
step: 599, loss is 0.18371582, fps is 58.04039977720966
|
||||
step: 600, loss is 0.22070312, fps is 56.99692546024671
|
||||
The model checkpoint will be saved in the current directory.
|
||||
|
||||
```
|
||||
|
||||
The model checkpoint will be saved in the current directory.
|
||||
|
||||
### Distributed Training
|
||||
#### Distributed Training
|
||||
|
||||
```shell
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
|
||||
|
@ -183,28 +205,26 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
|
|||
|
||||
- evaluation on ISBI dataset when running on Ascend
|
||||
|
||||
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-48_600.ckpt".
|
||||
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-48_600.ckpt".
|
||||
|
||||
```shell
|
||||
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
|
||||
```
|
||||
```shell
|
||||
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||
|
||||
```shell
|
||||
```shell
|
||||
# grep "Cross valid dice coeff is:" eval.log
|
||||
============== Cross valid dice coeff is: {'dice_coeff': 0.9085704886070473}
|
||||
```
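
The `dice_coeff` value reported above is a Dice coefficient. As a minimal sketch of how such a score is computed for binary masks (this is not the repository's metric implementation; the function name, the flat-list input, and the `eps` smoothing term are assumptions for illustration):

```python
def dice_coeff(pred, target, eps=1e-6):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for flat binary masks of 0/1 values."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # eps avoids division by zero when both masks are empty.
    return (2.0 * intersection + eps) / (total + eps)


# One overlapping pixel, 2 predicted and 1 labelled: 2 * 1 / (2 + 1)
print(round(dice_coeff([1, 1, 0, 0], [1, 0, 0, 0]), 3))  # 0.667
```

A score of 1.0 means the predicted and labelled masks coincide; values near 0 mean almost no overlap.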
|
||||
|
||||
# grep "Cross valid dice coeff is:" eval.log
|
||||
============== Cross valid dice coeff is: {'dice_coeff': 0.9085704886070473}
|
||||
## [Model Description](#contents)
|
||||
|
||||
```
|
||||
### [Performance](#contents)
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## Performance
|
||||
|
||||
### Evaluation Performance
|
||||
#### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ------------------------------------------------------------ |
|
||||
|
@ -225,105 +245,65 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
|
|||
| Checkpoint for Fine tuning | 355.11M (.ckpt file) |
|
||||
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
|
||||
|
||||
## [How to use](#contents)
|
||||
### [How to use](#contents)
|
||||
|
||||
### Inference
|
||||
#### Inference
|
||||
|
||||
If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, see this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). The steps below show a simple example:
|
||||
|
||||
- Running on Ascend
|
||||
##### Running on Ascend 310
|
||||
|
||||
```python
|
||||
Export MindIR
|
||||
|
||||
# Set context
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",save_graphs=True,device_id=device_id)
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
# Load unseen dataset for inference
|
||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
|
||||
The `ckpt_file` parameter is required.
|
||||
`FILE_FORMAT` should be one of ["AIR", "MINDIR"].
|
||||
|
||||
# Define model and Load pre-trained model
|
||||
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
param_dict= load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net , param_dict)
|
||||
criterion = CrossEntropyWithLogits()
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
Before running inference, export the MINDIR file with the export script in the Ascend 910 environment.
|
||||
Currently, batch_size can only be set to 1.
|
||||
|
||||
# Make predictions on the unseen dataset
|
||||
print("============== Starting Evaluating ============")
|
||||
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
|
||||
print("============== Cross valid dice coeff is:", dice_score)
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
```
|
||||
`DEVICE_ID` is optional; the default value is 0.
|
||||
|
||||
- Running on Ascend 310
|
||||
The inference result is saved in the current path; you can find it in the acc.log file.
|
||||
|
||||
Export MindIR
|
||||
```text
|
||||
Cross valid dice coeff is: 0.9054352151297033
|
||||
```
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
#### Continue Training on the Pretrained Model
|
||||
|
||||
The `ckpt_file` parameter is required.
|
||||
`FILE_FORMAT` should be one of ["AIR", "MINDIR"].
|
||||
Set the `resume` option to True in `config.py` and set `resume_ckpt` to the path of your checkpoint, e.g.:
|
||||
|
||||
Before running inference, export the MINDIR file with the export script in the Ascend 910 environment.
|
||||
Currently, batch_size can only be set to 1.
|
||||
```python
|
||||
'resume': True,
|
||||
'resume_ckpt': 'ckpt_0/ckpt_unet_medical_adam_1-1_600.ckpt',
|
||||
'transfer_training': False,
|
||||
'filter_weight': ["final.weight"]
|
||||
```
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
```
|
||||
#### Transfer training
|
||||
|
||||
`DEVICE_ID` is optional; the default value is 0.
|
||||
Follow the same steps as for resuming training above, and additionally set `transfer_training` to True. `filter_weight` lists the weights that are filtered out of the checkpoint when training on a different dataset. The default value of `filter_weight` usually does not need to be changed; it contains the weights that depend on the number of classes, e.g.:
|
||||
|
||||
The inference result is saved in the current path; you can find it in the acc.log file.
|
||||
```python
|
||||
'resume': True,
|
||||
'resume_ckpt': 'ckpt_0/ckpt_unet_medical_adam_1-1_600.ckpt',
|
||||
'transfer_training': True,
|
||||
'filter_weight': ["final.weight"]
|
||||
```
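
To make the effect of `filter_weight` concrete, the sketch below applies the same substring-based filtering as `filter_checkpoint_parameter_by_list` in src/utils.py to a plain dictionary. The dictionary stands in for the result of `load_checkpoint`; its parameter names and values are hypothetical.

```python
def filter_checkpoint_parameter_by_list(param_dict, filter_list):
    """Remove, in place, every parameter whose name contains a filter entry."""
    for key in list(param_dict.keys()):
        if any(name in key for name in filter_list):
            print("Delete parameter from checkpoint:", key)
            del param_dict[key]


# Hypothetical stand-in for the dict returned by load_checkpoint().
param_dict = {"inc.conv.weight": [0.1], "final.weight": [0.2], "final.bias": [0.3]}
filter_checkpoint_parameter_by_list(param_dict, ["final.weight"])
print(sorted(param_dict))  # ['final.bias', 'inc.conv.weight']
```

Matching is by substring, so only the listed names are dropped. A class-dependent bias has to be listed explicitly, as the `cfg_unet_medical` default does with `'outc.bias'`.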
|
||||
|
||||
```text
|
||||
Cross valid dice coeff is: 0.9054352151297033
|
||||
```
|
||||
|
||||
### Continue Training on the Pretrained Model
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```python
|
||||
# Define model
|
||||
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
# Continue training if set 'resume' to be True
|
||||
if cfg['resume']:
|
||||
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# Load dataset
|
||||
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
|
||||
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
|
||||
loss_scale=cfg['loss_scale'])
|
||||
criterion = CrossEntropyWithLogits()
|
||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
|
||||
|
||||
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
|
||||
# Set callbacks
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
keep_checkpoint_max=cfg['keep_checkpoint_max'])
|
||||
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
|
||||
directory='./ckpt_{}/'.format(device_id),
|
||||
config=ckpt_config)
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
|
||||
dataset_sink_mode=False)
|
||||
print("============== End Training ==============")
|
||||
```
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
## [Description of Random Situation](#contents)
|
||||
|
||||
In data_loader.py, we set the seed inside the `_get_val_train_indices` function. We also use a random seed in train.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
## [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
|
@ -1,43 +1,45 @@
|
|||
# Contents
|
||||
# Unet
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
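- [Contents](#contents)
|
||||
- [U-Net Description](#u-net-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Unet](#unet)
|
||||
- [U-Net Description](#u-net-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Usage](#usage)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Usage](#usage-1)
|
||||
- [Inference](#inference)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Usage](#usage-1)
|
||||
- [Inference](#inference)
|
||||
- [Running on Ascend 310](#running-on-ascend-310)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Transfer Training](#transfer-training)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)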
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# U-Net Description
|
||||
## U-Net Description
|
||||
|
||||
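The U-Net medical model performs 2D image segmentation. The implementation follows the paper [UNet: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597). In the 2015 ISBI cell tracking challenge, U-Net won several top awards. The paper proposes a network model and a data augmentation method for medical image segmentation that make effective use of annotated data, addressing the shortage of annotated data in the medical field. The U-shaped network structure is also used to extract context and location information.
|
||||
|
||||
[Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." *conditionally accepted at MICCAI 2015*. 2015.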
|
||||
|
||||
# Model Architecture
|
||||
## Model Architecture
|
||||
|
||||
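Specifically, U-Net's U-shaped network structure extracts and fuses high-level features more effectively and captures both context and spatial location information. The U-shaped structure consists of an encoder and a decoder. The encoder repeats two 3x3 convolutions followed by a 2x2 max pooling, and the number of channels doubles after each downsampling. The decoder consists of a 2x2 deconvolution, a concat layer, and two 3x3 convolutions, and the output is produced after a 1x1 convolution.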
|
||||
|
||||
# Dataset
|
||||
## Dataset
|
||||
|
||||
Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
|
||||
|
||||
|
@ -51,7 +53,7 @@ U-Net医学模型基于二维图像分割。实现方式见论文[UNet:Convolu
|
|||
- Data format: binary files (TIF)
|
||||
- Note: data is processed in src/data_loader.py
|
||||
|
||||
# Environment Requirements
|
||||
## Environment Requirements
|
||||
|
||||
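- Hardware (Ascend)
|
||||
- Prepare the hardware environment with an Ascend processor. To try Ascend processors, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com; resources are granted once the application is approved.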
|
||||
|
@ -61,7 +63,7 @@ U-Net医学模型基于二维图像分割。实现方式见论文[UNet:Convolu
|
|||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# Quick Start
|
||||
## Quick Start
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
|
@ -82,9 +84,27 @@ U-Net医学模型基于二维图像分割。实现方式见论文[UNet:Convolu
|
|||
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
|
||||
```
|
||||
|
||||
# Script Description
|
||||
- Run on Docker
|
||||
|
||||
## Script and Sample Code
|
||||
Build the Docker image (change the version to the one you actually use):
|
||||
|
||||
```shell
|
||||
# build docker
|
||||
docker build -t unet:20.1.0 . --build-arg FROM_IMAGE_NAME=ascend-mindspore-arm:20.1.0
|
||||
```
|
||||
|
||||
Start a container from the built image:
|
||||
|
||||
```shell
|
||||
# start docker
|
||||
bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
|
||||
```
|
||||
|
||||
You can then run everything inside the container just as on the Ascend platform.
|
||||
|
||||
## Script Description
|
||||
|
||||
### Script and Sample Code
|
||||
|
||||
```path
|
||||
├── model_zoo
|
||||
|
@ -108,7 +128,7 @@ U-Net医学模型基于二维图像分割。实现方式见论文[UNet:Convolu
|
|||
├── eval.py // evaluation script
|
||||
```
|
||||
|
||||
## Script Parameters
|
||||
### Script Parameters
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
|
||||
|
@ -202,11 +222,11 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
|
|||
============== Cross valid dice coeff is: {'dice_coeff': 0.9085704886070473}
|
||||
```
|
||||
|
||||
# Model Description
|
||||
## Model Description
|
||||
|
||||
## Performance
|
||||
### Performance
|
||||
|
||||
### Evaluation Performance
|
||||
#### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ------------------------------------------------------------ |
|
||||
|
@ -227,103 +247,64 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
|
|||
| Checkpoint for fine-tuning | 355.11M (.ckpt file) |
|
||||
| Scripts | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
|
||||
|
||||
## Usage
|
||||
### Usage
|
||||
|
||||
### Inference
|
||||
#### Inference
|
||||
|
||||
If you need to run inference with the trained model on multiple hardware platforms, such as Ascend 910 or Ascend 310, see this [link](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html). The steps below show a simple example:
|
||||
|
||||
- Running on Ascend
|
||||
##### Running on Ascend 310
|
||||
|
||||
```python
|
||||
# Set context
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",save_graphs=True,device_id=device_id)
|
||||
Export MindIR
|
||||
|
||||
# Load unseen dataset for inference
|
||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
# Define model and load pre-trained model
|
||||
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
param_dict= load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net , param_dict)
|
||||
criterion = CrossEntropyWithLogits()
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
The `ckpt_file` parameter is required; `FILE_FORMAT` must be one of ["AIR", "MINDIR"].
|
||||
|
||||
# Make predictions on the unseen dataset
|
||||
print("============== Starting Evaluating ============")
|
||||
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
|
||||
print("============== Cross valid dice coeff is:", dice_score)
|
||||
Before running inference, export the MINDIR file with export.py in the Ascend 910 environment.
|
||||
Currently, only a batch_size of 1 is supported.
|
||||
|
||||
```
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
- Running on Ascend 310
|
||||
`DEVICE_ID` is optional; the default value is 0.
|
||||
|
||||
Export MindIR
|
||||
The inference result is saved in the current path; the final accuracy can be found in acc.log.
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
```text
|
||||
Cross valid dice coeff is: 0.9054352151297033
|
||||
```
|
||||
|
||||
The `ckpt_file` parameter is required; `FILE_FORMAT` must be one of ["AIR", "MINDIR"].
|
||||
#### Continue Training on the Pretrained Model
|
||||
|
||||
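Before running inference, export the MINDIR file with export.py in the Ascend 910 environment.
|
||||
Currently, only a batch_size of 1 is supported.
|
||||
Set `resume` to True in `config.py` and set `resume_ckpt` to the path of the corresponding checkpoint file, e.g.: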
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
```
|
||||
```python
|
||||
'resume': True,
|
||||
'resume_ckpt': 'ckpt_0/ckpt_unet_medical_adam_1-1_600.ckpt',
|
||||
'transfer_training': False,
|
||||
'filter_weight': ["final.weight"]
|
||||
```
|
||||
|
||||
`DEVICE_ID` is optional; the default value is 0.
|
||||
#### Transfer Training
|
||||
|
||||
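The inference result is saved in the current path; the final accuracy can be found in acc.log.
|
||||
First load the checkpoint as described above for resuming training, then set `transfer_training` to True. The configuration also has a `filter_weight` parameter, which filters out weights that cannot be reused across datasets. `filter_weight` usually does not need to be changed; its default value contains the parameters that depend on the model's number of classes, e.g.: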
|
||||
|
||||
```text
|
||||
Cross valid dice coeff is: 0.9054352151297033
|
||||
```
|
||||
```python
|
||||
'resume': True,
|
||||
'resume_ckpt': 'ckpt_0/ckpt_unet_medical_adam_1-1_600.ckpt',
|
||||
'transfer_training': True,
|
||||
'filter_weight': ["final.weight"]
|
||||
```
|
||||
|
||||
### Continue Training on the Pretrained Model
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
```python
|
||||
# Define model
|
||||
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
||||
# Continue training if 'resume' is True
|
||||
if cfg['resume']:
|
||||
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# Load dataset
|
||||
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
|
||||
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
|
||||
loss_scale=cfg['loss_scale'])
|
||||
criterion = CrossEntropyWithLogits()
|
||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
|
||||
|
||||
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
|
||||
# Set callbacks
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
keep_checkpoint_max=cfg['keep_checkpoint_max'])
|
||||
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
|
||||
directory='./ckpt_{}/'.format(device_id),
|
||||
config=ckpt_config)
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
|
||||
dataset_sink_mode=False)
|
||||
print("============== End Training ==============")
|
||||
```
|
||||
|
||||
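# Description of Random Situation
|
||||
## Description of Random Situation
|
||||
|
||||
In dataset.py, the seed is set inside the "seet_sed" function; the random seed in train.py is also used.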
|
||||
|
||||
# ModelZoo Homepage
|
||||
## ModelZoo Homepage
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
|
||||
docker_image=$1
|
||||
data_dir=$2
|
||||
model_dir=$3
|
||||
|
||||
docker run -it --ipc=host \
|
||||
--device=/dev/davinci0 \
|
||||
--device=/dev/davinci1 \
|
||||
--device=/dev/davinci2 \
|
||||
--device=/dev/davinci3 \
|
||||
--device=/dev/davinci4 \
|
||||
--device=/dev/davinci5 \
|
||||
--device=/dev/davinci6 \
|
||||
--device=/dev/davinci7 \
|
||||
--device=/dev/davinci_manager \
|
||||
--device=/dev/devmm_svm \
|
||||
--device=/dev/hisi_hdc \
|
||||
--privileged \
|
||||
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
||||
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons \
|
||||
-v ${data_dir}:${data_dir} \
|
||||
-v ${model_dir}:${model_dir} \
|
||||
-v /var/log/npu/conf/slog/slog.conf:/var/log/npu/conf/slog/slog.conf \
|
||||
-v /var/log/npu/slog/:/var/log/npu/slog/ \
|
||||
-v /var/log/npu/profiling/:/var/log/npu/profiling \
|
||||
-v /var/log/npu/dump/:/var/log/npu/dump \
|
||||
-v /var/log/npu/:/usr/slog ${docker_image} \
|
||||
/bin/bash
|
|
@ -32,6 +32,8 @@ cfg_unet_medical = {
|
|||
|
||||
'resume': False,
|
||||
'resume_ckpt': './',
|
||||
'transfer_training': False,
|
||||
'filter_weight': ['outc.weight', 'outc.bias']
|
||||
}
|
||||
|
||||
cfg_unet_nested = {
|
||||
|
@ -56,6 +58,8 @@ cfg_unet_nested = {
|
|||
|
||||
'resume': False,
|
||||
'resume_ckpt': './',
|
||||
'transfer_training': False,
|
||||
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
|
||||
}
|
||||
|
||||
cfg_unet_nested_cell = {
|
||||
|
@ -81,6 +85,8 @@ cfg_unet_nested_cell = {
|
|||
|
||||
'resume': False,
|
||||
'resume_ckpt': './',
|
||||
'transfer_training': False,
|
||||
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
|
||||
}
|
||||
|
||||
cfg_unet_simple = {
|
||||
|
@ -102,6 +108,8 @@ cfg_unet_simple = {
|
|||
|
||||
'resume': False,
|
||||
'resume_ckpt': './',
|
||||
'transfer_training': False,
|
||||
'filter_weight': ["final.weight"]
|
||||
}
|
||||
|
||||
cfg_unet = cfg_unet_medical
|
||||
|
|
|
@ -68,6 +68,15 @@ class StepLossTimeMonitor(Callback):
|
|||
print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
|
||||
cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
|
||||
|
||||
|
||||
def mask_to_image(mask):
|
||||
return Image.fromarray((mask * 255).astype(np.uint8))
|
||||
|
||||
|
||||
def filter_checkpoint_parameter_by_list(param_dict, filter_list):
|
||||
"""remove useless parameters according to filter_list"""
|
||||
for key in list(param_dict.keys()):
|
||||
for name in filter_list:
|
||||
if name in key:
|
||||
print("Delete parameter from checkpoint: ", key)
|
||||
del param_dict[key]
|
||||
break
|
||||
|
|
|
@ -30,7 +30,7 @@ from src.unet_medical import UNetMedical
|
|||
from src.unet_nested import NestedUNet, UNet
|
||||
from src.data_loader import create_dataset, create_cell_nuclei_dataset
|
||||
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
|
||||
from src.utils import StepLossTimeMonitor
|
||||
from src.utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list
|
||||
from src.config import cfg_unet
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
@ -45,7 +45,6 @@ def train_net(data_dir,
|
|||
lr=0.0001,
|
||||
run_distribute=False,
|
||||
cfg=None):
|
||||
|
||||
rank = 0
|
||||
group_size = 1
|
||||
if run_distribute:
|
||||
|
@ -69,6 +68,8 @@ def train_net(data_dir,
|
|||
|
||||
if cfg['resume']:
|
||||
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
||||
if cfg['transfer_training']:
|
||||
filter_checkpoint_parameter_by_list(param_dict, cfg['filter_weight'])
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
if 'use_ds' in cfg and cfg['use_ds']:
|
||||
|
|