Add ModelZoo Network: Unet.

This commit is contained in:
zhanghuiyao 2020-09-17 09:33:16 +08:00
parent 4d6bbd1218
commit 776eb28e6e
13 changed files with 1096 additions and 0 deletions

View File

@ -0,0 +1,273 @@
# Contents
- [Unet Description](#unet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [How to use](#how-to-use)
- [Inference](#inference)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Unet Description](#contents)
Unet Medical model for 2D image segmentation. This implementation is as described in the original paper [UNet: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597). Unet, in the 2015 ISBI cell tracking competition, many of the best are obtained. In this paper, a network model for medical image segmentation is proposed, and a data enhancement method is proposed to effectively use the annotation data to solve the problem of insufficient annotation data in the medical field. A U-shaped network structure is also used to extract the context and location information.
[Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." * conditionally accepted at MICCAI 2015*. 2015.
# [Model Architecture](#contents)
Specifically, the U network structure is proposed in UNET, which can better extract and fuse high-level features and obtain context information and spatial location information. The U network structure is composed of encoder and decoder. The encoder is composed of two 3x3 conv and a 2x2 max pooling iteration. The number of channels is doubled after each down sampling. The decoder is composed of a 2x2 deconv, concat layer and two 3x3 convolutions, and then outputs after a 1x1 convolution.
# [Dataset](#contents)
Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
- Description: The training and test datasets are two stacks of 30 sections from a serial section Transmission Electron Microscopy (ssTEM) data set of the Drosophila first instar larva ventral nerve cord (VNC). The microcube measures 2 x 2 x 1.5 microns approx., with a resolution of 4x4x50 nm/pixel.
- License: You are free to use this data set for the purpose of generating or testing non-commercial image segmentation software. If any scientific publications derive from the usage of this data set, you must cite TrakEM2 and the following publication: Cardona A, Saalfeld S, Preibisch S, Schmid B, Cheng A, Pulokas J, Tomancak P, Hartenstein V. 2010. An Integrated Micro- and Macroarchitectural Analysis of the Drosophila Brain by Computer-Assisted Serial Section Electron Microscopy. PLoS Biol 8(10): e1000502. doi:10.1371/journal.pbio.1000502.
- Dataset size22.5M
- Train15M, 30 images (Training data contains 2 multi-page TIF files, each containing 30 2D-images. train-volume.tif and train-labels.tif respectly contain data and label.)
- Val(We randomly divde the training data into 5-fold and evaluate the model by across 5-fold cross-validation.)
- Test7.5M, 30 images (Testing data contains 1 multi-page TIF files, each containing 30 2D-images. test-volume.tif respectly contain data.)
- Data formatbinary files(TIF file)
- NoteData will be processed in src/data_loader.py
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- running on Ascend
```python
# run training example
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
OR
bash scripts/run_standalone_train.sh [DATASET]
# run distributed training example
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
# run evaluation example
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
├── model_zoo
├── README.md // descriptions about all the models
├── unet
├── README.md // descriptions about Unet
├── scripts
│ ├──run_standalone_train.sh // shell script for distributed on Ascend
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
├── src
│ ├──config.py // parameter configuration
│ ├──data_loader.py // creating dataset
│ ├──loss.py // loss
│ ├──utils.py // General components (callback function)
│ ├──unet.py // Unet architecture
├──__init__.py // init file
├──unet_model.py // unet model
├──unet_parts.py // unet part
├── train.py // training script
├──launch_8p.py // training 8P script
├── eval.py // evaluation script
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
- config for Unet, ISBI dataset
```python
'name': 'Unet', # model name
'lr': 0.0001, # learning rate
'epochs': 400, # total training epochs when run 1p
'distribute_epochs': 1600, # total training epochs when run 8p
'batchsize': 16, # training batch size
'cross_valid_ind': 1, # cross valid ind
'num_classes': 2, # the number of classes in the dataset
'num_channels': 1, # the number of channels
'keep_checkpoint_max': 10, # only keep the last keep_checkpoint_max checkpoint
'weight_decay': 0.0005, # weight decay value
'loss_scale': 1024.0, # loss scale
'FixedLossScaleManager': 1024.0, # fix loss scale
```
## [Training Process](#contents)
### Training
- running on Ascend
```
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
OR
bash scripts/run_standalone_train.sh [DATASET]
```
The python command above will run in the background, you can view the results through the file `train.log`.
After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
```
# grep "loss is " train.log
step: 1, loss is 0.7011719, fps is 0.25025035060906264
step: 2, loss is 0.69433594, fps is 56.77693756377044
step: 3, loss is 0.69189453, fps is 57.3293877244179
step: 4, loss is 0.6894531, fps is 57.840651522059716
step: 5, loss is 0.6850586, fps is 57.89903776054361
step: 6, loss is 0.6777344, fps is 58.08073627299014
...
step: 597, loss is 0.19030762, fps is 58.28088370287449
step: 598, loss is 0.19958496, fps is 57.95493929352674
step: 599, loss is 0.18371582, fps is 58.04039977720966
step: 600, loss is 0.22070312, fps is 56.99692546024671
```
The model checkpoint will be saved in the current directory.
### Distributed Training
```
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
```
The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows:
```
# grep "loss is" logs/device0/log.log
step: 1, loss is 0.70524895, fps is 0.15914689861221412
step: 2, loss is 0.6925452, fps is 56.43668656967454
...
step: 299, loss is 0.20551169, fps is 58.4039329983891
step: 300, loss is 0.18949677, fps is 57.63118508760329
```
## [Evaluation Process](#contents)
### Evaluation
- evaluation on ISBI dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-48_600.ckpt".
```
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```
# grep "Cross valid dice coeff is:" eval.log
============== Cross valid dice coeff is: {'dice_coeff': 0.9085704886070473}
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | Unet |
| Resource | Ascend 910 ;CPU 2.60GHz,56cores; Memory,314G |
| uploaded Date | 09/15/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | ISBI |
| Training Parameters | 1pc: epoch=400, total steps=600, batch_size = 16, lr=0.0001 |
| | 8pc: epoch=1600, total steps=300, batch_size = 16, lr=0.0001 |
| Optimizer | ADAM |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.22070312 |
| Speed | 1pc: 267 ms/step; 8pc: 280 ms/step; |
| Total time | 1pc: 2.67 mins; 8pc: 1.40 mins |
| Parameters (M) | 93M |
| Checkpoint for Fine tuning | 355.11M (.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
## [How to use](#contents)
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example:
- Running on Ascend
```
# Set context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",save_graphs=True,device_id=device_id)
# Load unseen dataset for inference
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
# Define model and Load pre-trained model
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
param_dict= load_checkpoint(ckpt_path)
load_param_into_net(net , param_dict)
criterion = CrossEntropyWithLogits()
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Make predictions on the unseen dataset
print("============== Starting Evaluating ============")
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
print("============== Cross valid dice coeff is:", dice_score)
```
### Transfer Learning
To be added.
# [Description of Random Situation](#contents)
In data_loader.py, we set the seed inside “_get_val_train_indices" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,123 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import logging
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops.operations as F
from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss.loss import _Loss
from src.data_loader import create_dataset
from src.unet import UNet
from src.config import cfg_unet
from scipy.special import softmax
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
class CrossEntropyWithLogits(_Loss):
def __init__(self):
super(CrossEntropyWithLogits, self).__init__()
self.transpose_fn = F.Transpose()
self.reshape_fn = F.Reshape()
self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits()
self.cast = F.Cast()
def construct(self, logits, label):
# NCHW->NHWC
logits = self.transpose_fn(logits, (0, 2, 3, 1))
logits = self.cast(logits, mindspore.float32)
label = self.transpose_fn(label, (0, 2, 3, 1))
loss = self.reduce_mean(self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)),
self.reshape_fn(label, (-1, 2))))
return self.get_loss(loss)
class dice_coeff(nn.Metric):
def __init__(self):
super(dice_coeff, self).__init__()
self.clear()
def clear(self):
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Mean dice coeffcient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
y_pred = y_pred.transpose(0, 2, 3, 1)
y = y.transpose(0, 2, 3, 1)
y_pred = softmax(y_pred, axis=3)
inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2*float(inter)/float(union+1e-6)
print("single dice coeff is:", single_dice_coeff)
self._dice_coeff_sum += single_dice_coeff
def eval(self):
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
def test_net(data_dir,
ckpt_path,
cross_valid_ind=1,
cfg=None):
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
criterion = CrossEntropyWithLogits()
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
print("============== Starting Evaluating ============")
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
print("============== Cross valid dice coeff is:", dice_score)
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-p', '--ckpt_path', dest='ckpt_path', type=str, default='ckpt_unet_medical_adam-1_600.ckpt',
help='checkpoint path')
return parser.parse_args()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
print("Testing setting:", args)
test_net(data_dir=args.data_url,
ckpt_path=args.ckpt_path,
cross_valid_ind=cfg_unet['cross_valid_ind'],
cfg=cfg_unet)

View File

@ -0,0 +1,50 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data"
echo "=============================================================================================================="
if [ $# != 2 ]
then
echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
exit 1
fi
export RANK_SIZE=8
for((i=0;i<RANK_SIZE;i++))
do
rm -rf LOG$i
mkdir ./LOG$i
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cd ./LOG$i || exit
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export RANK_ID=$i
export DEVICE_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python3 train.py \
--run_distribute=True \
--data_url=$2 > log.txt 2>&1 &
cd ../
done

View File

@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
echo "=============================================================================================================="
export DEVICE_ID=0
python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 &

View File

@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train.sh [DATASET]"
echo "for example: bash run_standalone_train.sh /path/to/data/"
echo "=============================================================================================================="
export DEVICE_ID=0
python train.py --data_url=$1 > train.log 2>&1 &

View File

@ -0,0 +1,30 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cfg_unet = {
'name': 'Unet',
'lr': 0.0001,
'epochs': 400,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 1,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
}

View File

@ -0,0 +1,159 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from collections import deque
import numpy as np
from PIL import Image, ImageSequence
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.vision.utils import Inter
from mindspore.communication.management import get_rank, get_group_size
def _load_multipage_tiff(path):
"""Load tiff images containing many images in the channel dimension"""
return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
def _get_val_train_indices(length, fold, ratio=0.8):
assert 0 < ratio <= 1, "Train/total data ratio must be in range (0.0, 1.0]"
np.random.seed(0)
indices = np.arange(0, length, 1, dtype=np.int)
np.random.shuffle(indices)
if fold is not None:
indices = deque(indices)
indices.rotate(fold * round((1.0 - ratio) * length))
indices = np.array(indices)
train_indices = indices[:round(ratio * len(indices))]
val_indices = indices[round(ratio * len(indices)):]
else:
train_indices = indices
val_indices = []
return train_indices, val_indices
def data_post_process(img, mask):
img = np.expand_dims(img, axis=0)
mask = (mask > 0.5).astype(np.int)
mask = (np.arange(mask.max() + 1) == mask[..., None]).astype(int)
mask = mask.transpose(2, 0, 1).astype(np.float32)
return img, mask
def train_data_augmentation(img, mask):
h_flip = np.random.random()
if h_flip > 0.5:
img = np.flipud(img)
mask = np.flipud(mask)
v_flip = np.random.random()
if v_flip > 0.5:
img = np.fliplr(img)
mask = np.fliplr(mask)
left = int(np.random.uniform()*0.3*572)
right = int((1-np.random.uniform()*0.3)*572)
top = int(np.random.uniform()*0.3*572)
bottom = int((1-np.random.uniform()*0.3)*572)
img = img[top:bottom, left:right]
mask = mask[top:bottom, left:right]
#adjust brightness
brightness = np.random.uniform(-0.2, 0.2)
img = np.float32(img+brightness*np.ones(img.shape))
img = np.clip(img, -1.0, 1.0)
return img, mask
def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False):
images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif'))
masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif'))
train_indices, val_indices = _get_val_train_indices(len(images), cross_val_ind)
train_images = images[train_indices]
train_masks = masks[train_indices]
train_images = np.repeat(train_images, repeat, axis=0)
train_masks = np.repeat(train_masks, repeat, axis=0)
val_images = images[val_indices]
val_masks = masks[val_indices]
train_image_data = {"image": train_images}
train_mask_data = {"mask": train_masks}
valid_image_data = {"image": val_images}
valid_mask_data = {"mask": val_masks}
ds_train_images = ds.NumpySlicesDataset(data=train_image_data, sampler=None, shuffle=False)
ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data, sampler=None, shuffle=False)
if run_distribute:
rank_id = get_rank()
rank_size = get_group_size()
ds_train_images = ds.NumpySlicesDataset(data=train_image_data,
sampler=None,
shuffle=False,
num_shards=rank_size,
shard_id=rank_id)
ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data,
sampler=None,
shuffle=False,
num_shards=rank_size,
shard_id=rank_id)
ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
c_resize_op = c_vision.Resize(size=(388, 388), interpolation=Inter.BILINEAR)
c_pad = c_vision.Pad(padding=92)
c_rescale_image = c_vision.Rescale(1.0/127.5, -1)
c_rescale_mask = c_vision.Rescale(1.0/255.0, 0)
c_trans_normalize_img = [c_rescale_image, c_resize_op, c_pad]
c_trans_normalize_mask = [c_rescale_mask, c_resize_op, c_pad]
c_center_crop = c_vision.CenterCrop(size=388)
train_image_ds = ds_train_images.map(input_columns="image", operations=c_trans_normalize_img)
train_mask_ds = ds_train_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
train_ds = ds.zip((train_image_ds, train_mask_ds))
train_ds = train_ds.project(columns=["image", "mask"])
if augment:
augment_process = train_data_augmentation
c_resize_op = c_vision.Resize(size=(572, 572), interpolation=Inter.BILINEAR)
train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process)
train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process
train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
train_ds = train_ds.shuffle(repeat*24)
train_ds = train_ds.batch(batch_size=train_batch_size, drop_remainder=True)
valid_image_ds = ds_valid_images.map(input_columns="image", operations=c_trans_normalize_img)
valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
valid_ds = valid_ds.project(columns=["image", "mask"])
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process
valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
return train_ds, valid_ds

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore
import mindspore.nn as nn
import mindspore.ops.operations as F
from mindspore.nn.loss.loss import _Loss
class CrossEntropyWithLogits(_Loss):
def __init__(self):
super(CrossEntropyWithLogits, self).__init__()
self.transpose_fn = F.Transpose()
self.reshape_fn = F.Reshape()
self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits()
self.cast = F.Cast()
def construct(self, logits, label):
# NCHW->NHWC
logits = self.transpose_fn(logits, (0, 2, 3, 1))
logits = self.cast(logits, mindspore.float32)
label = self.transpose_fn(label, (0, 2, 3, 1))
loss = self.reduce_mean(
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
return self.get_loss(loss)

View File

@ -0,0 +1,16 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .unet_model import UNet

View File

@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from src.unet.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv
import mindspore.nn as nn
class UNet(nn.Cell):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024)
self.up1 = Up1(1024, 512)
self.up2 = Up2(512, 256)
self.up3 = Up3(256, 128)
self.up4 = Up4(128, 64)
self.outc = OutConv(64, n_classes)
def construct(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits

View File

@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Parts of the U-Net model """
import mindspore.nn as nn
import mindspore.ops.operations as F
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import CentralCrop
class DoubleConv(nn.Cell):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
init_value_0 = TruncatedNormal(0.06)
init_value_1 = TruncatedNormal(0.06)
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.SequentialCell(
[nn.Conv2d(in_channels, mid_channels, kernel_size=3, has_bias=True,
weight_init=init_value_0, pad_mode="valid"),
nn.ReLU(),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, has_bias=True,
weight_init=init_value_1, pad_mode="valid"),
nn.ReLU()]
)
def construct(self, x):
return self.double_conv(x)
class Down(nn.Cell):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.SequentialCell(
[nn.MaxPool2d(kernel_size=2, stride=2),
DoubleConv(in_channels, out_channels)]
)
def construct(self, x):
return self.maxpool_conv(x)
class Up1(nn.Cell):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
self.concat = F.Concat(axis=1)
self.factor = 56.0 / 64.0
self.center_crop = CentralCrop(central_fraction=self.factor)
self.print_fn = F.Print()
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.relu = nn.ReLU()
def construct(self, x1, x2):
x1 = self.up(x1)
x1 = self.relu(x1)
x2 = self.center_crop(x2)
x = self.concat((x1, x2))
return self.conv(x)
class Up2(nn.Cell):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
self.concat = F.Concat(axis=1)
self.factor = 104.0 / 136.0
self.center_crop = CentralCrop(central_fraction=self.factor)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.relu = nn.ReLU()
def construct(self, x1, x2):
x1 = self.up(x1)
x1 = self.relu(x1)
x2 = self.center_crop(x2)
x = self.concat((x1, x2))
return self.conv(x)
class Up3(nn.Cell):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
self.concat = F.Concat(axis=1)
self.factor = 200 / 280
self.center_crop = CentralCrop(central_fraction=self.factor)
self.print_fn = F.Print()
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.relu = nn.ReLU()
def construct(self, x1, x2):
x1 = self.up(x1)
x1 = self.relu(x1)
x2 = self.center_crop(x2)
x = self.concat((x1, x2))
return self.conv(x)
class Up4(nn.Cell):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
self.concat = F.Concat(axis=1)
self.factor = 392 / 568
self.center_crop = CentralCrop(central_fraction=self.factor)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.relu = nn.ReLU()
def construct(self, x1, x2):
x1 = self.up(x1)
x1 = self.relu(x1)
x2 = self.center_crop(x2)
x = self.concat((x1, x2))
return self.conv(x)
class OutConv(nn.Cell):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
init_value = TruncatedNormal(0.06)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=True, weight_init=init_value)
def construct(self, x):
x = self.conv(x)
return x

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import numpy as np
from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor
class StepLossTimeMonitor(Callback):
def __init__(self, batch_size, per_print_times=1):
super(StepLossTimeMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.batch_size = batch_size
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_seconds = time.time() - self.step_time
step_fps = self.batch_size*1.0/step_seconds
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
# TEST
print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
# print("step: %s, loss is %s, fps is %s" % ( cur_step_in_epoch, loss, step_fps))

View File

@ -0,0 +1,106 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import logging
import ast
import mindspore
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.communication.management import init, get_group_size
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode
from src.unet import UNet
from src.data_loader import create_dataset
from src.loss import CrossEntropyWithLogits
from src.utils import StepLossTimeMonitor
from src.config import cfg_unet
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
mindspore.set_seed(1)
def train_net(data_dir,
cross_valid_ind=1,
epochs=400,
batch_size=16,
lr=0.0001,
run_distribute=False,
cfg=None):
if run_distribute:
init()
group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size,
parameter_broadcast=True,
gradients_mean=False)
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
criterion = CrossEntropyWithLogits()
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=cfg['keep_checkpoint_max'])
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
directory='./ckpt_{}/'.format(device_id),
config=ckpt_config)
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
loss_scale=cfg['loss_scale'])
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
print("============== Starting Training ==============")
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
dataset_sink_mode=False)
print("============== End Training ==============")
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
default=False, help='Run distribute, default: false.')
return parser.parse_args()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
print("Training setting:", args)
epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs']
train_net(data_dir=args.data_url,
cross_valid_ind=cfg_unet['cross_valid_ind'],
epochs=epoch_size,
batch_size=cfg_unet['batchsize'],
lr=cfg_unet['lr'],
run_distribute=args.run_distribute,
cfg=cfg_unet)