From: @Somnus2020
Reviewed-by: @linqingke,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-03-25 15:40:19 +08:00 committed by Gitee
commit f8198ece2e
16 changed files with 1605 additions and 0 deletions


@@ -0,0 +1,263 @@
# Contents
- [Contents](#contents)
- [Unet Description](#unet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [running on Ascend](#running-on-ascend)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [Unet Description](#contents)
The Unet3D model is widely used for 3D medical image segmentation. The architecture of Unet3D is similar to Unet; the main difference is that Unet3D uses 3D operations such as Conv3D, while Unet is an entirely 2D architecture. To learn more about the Unet3D network, you can read the original paper: Unet3D: Learning Dense Volumetric Segmentation from Sparse Annotation.
## [Model Architecture](#contents)
The Unet3D model is built on the previous Unet (2D) and consists of an encoder part and a decoder part. The encoder analyzes the whole volume and extracts features, while the decoder generates the segmented output. In this model, we also add a residual block to the base block to improve the network.
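The residual base block follows the usual pattern: the output of the stacked convolutions is added to a shortcut of the input (identity, or a projection when channels change). Below is a minimal sketch of that idea with identity stand-ins for the convolutions; the `residual_unit` helper is illustrative only, and the actual implementation is `ResidualUnit` in `src/unet3d_parts.py`.

```python
import numpy as np

def residual_unit(x, conv_block, shortcut):
    """out = conv_block(x) + shortcut(x); shortcut is identity or a 1x1x1 projection."""
    return conv_block(x) + shortcut(x)

# toy demonstration with identity stand-ins so the shapes line up (layout NCDHW)
x = np.random.randn(1, 16, 8, 8, 8).astype(np.float32)
y = residual_unit(x, conv_block=lambda t: t, shortcut=lambda t: t)
print(y.shape)  # (1, 16, 8, 8, 8)
```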
## [Dataset](#contents)
Dataset used: [LUNA16](https://luna16.grand-challenge.org/)
- Description: The task is to automatically detect the location of nodules from volumetric CT images. 888 CT scans from the LIDC-IDRI database are provided. The complete dataset is divided into 10 subsets that should be used for the 10-fold cross-validation. All subsets are available as compressed zip files.
- Dataset size: 888 CT scans
    - Train: 878 images
    - Test: 10 images
- Data format: zip
    - Note: Data will be processed in convert_nifti.py
## [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- Select the network and dataset to use
```shell
# Convert dataset into nifti format.
python ./src/convert_nifti.py --input_path=/path/to/input_image/ --output_path=/path/to/output_image/
```
Refer to `src/config.py`; it provides the parameter configuration used for quick start.
- Run on Ascend
```python
# run training example
python train.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ > train.log 2>&1 &
# run distributed training example
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]
# run evaluation example
python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```text
.
└─unet3d
  ├── README.md                     // descriptions about Unet3D
  ├── scripts
  │   ├── run_distribute_train.sh   // shell script for distributed training on Ascend
  │   ├── run_standalone_train.sh   // shell script for standalone training on Ascend
  │   ├── run_standalone_eval.sh    // shell script for evaluation on Ascend
  ├── src
  │   ├── config.py                 // parameter configuration
  │   ├── dataset.py                // creating dataset
  │   ├── lr_schedule.py            // learning rate scheduler
  │   ├── transform.py              // dataset transforms
  │   ├── convert_nifti.py          // convert dataset into nifti format
  │   ├── loss.py                   // loss function
  │   ├── conv.py                   // conv components
  │   ├── utils.py                  // general components (sliding window, dice calculation)
  │   ├── unet3d_model.py           // Unet3D model
  │   ├── unet3d_parts.py           // Unet3D parts
  ├── train.py                      // training script
  ├── eval.py                       // evaluation script
```
### [Script Parameters](#contents)
Parameters for both training and evaluation can be set in `config.py`.
- Config for Unet3d, LUNA16 dataset
```python
'model': 'Unet3d',              # model name
'lr': 0.0005,                   # learning rate
'epoch_size': 10,               # total training epochs when run on 1p
'batch_size': 1,                # training batch size
'warmup_step': 120,             # warmup steps for the lr generator
'warmup_ratio': 0.3,            # warmup ratio
'num_classes': 4,               # the number of classes in the dataset
'in_channels': 1,               # the number of input channels
'keep_checkpoint_max': 5,       # only keep the last keep_checkpoint_max checkpoints
'loss_scale': 256.0,            # loss scale
'roi_size': [224, 224, 96],     # random roi size
'overlap': 0.25,                # overlap rate
'min_val': -500,                # intensity original range min
'max_val': 1000,                # intensity original range max
'upper_limit': 5,               # upper limit of num_classes
'lower_limit': 3,               # lower limit of num_classes
```
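The `roi_size` and `overlap` settings also determine how many sliding windows are evaluated per CT scan. A rough estimate in plain Python, mirroring the interval rule in `src/utils.py` (the example volume size below is an assumption for illustration only):

```python
import math

image_size = (512, 512, 160)   # hypothetical volume size after preprocessing
roi_size = (224, 224, 96)
overlap = 0.25

windows_per_dim = []
for size, roi in zip(image_size, roi_size):
    interval = max(int(roi * (1 - overlap)), 1)   # scan interval, as in _get_scan_interval
    n = math.ceil((size - roi) / interval) + 1 if size > roi else 1
    windows_per_dim.append(n)

total = windows_per_dim[0] * windows_per_dim[1] * windows_per_dim[2]
print(windows_per_dim, "->", total, "windows per scan")  # [3, 3, 2] -> 18
```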
## [Training Process](#contents)
### Training
#### running on Ascend
```shell
python train.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ > train.log 2>&1 &
```
The python command above will run in the background; you can view the results through the file `train.log`.
After training, you will get checkpoint files under the script folder by default. The loss values will be printed as follows:
```shell
epoch: 1 step: 878, loss is 0.55011123
epoch time: 1443410.353 ms, per step time: 1688.199 ms
epoch: 2 step: 878, loss is 0.58278626
epoch time: 1172136.839 ms, per step time: 1370.920 ms
epoch: 3 step: 878, loss is 0.43625978
epoch time: 1135890.834 ms, per step time: 1328.537 ms
epoch: 4 step: 878, loss is 0.06556784
epoch time: 1180467.795 ms, per step time: 1380.664 ms
```
#### Distributed Training
> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as shown in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compilation time increases with model size.
>
```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]
```
The above shell script will run distributed training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss values will be printed as follows:
```shell
epoch: 1 step: 110, loss is 0.8294426
epoch time: 468891.643 ms, per step time: 4382.165 ms
epoch: 2 step: 110, loss is 0.58278626
epoch time: 165469.201 ms, per step time: 1546.441 ms
epoch: 3 step: 110, loss is 0.43625978
epoch time: 158915.771 ms, per step time: 1485.194 ms
...
epoch: 9 step: 110, loss is 0.016280059
epoch time: 172815.179 ms, per step time: 1615.095 ms
epoch: 10 step: 110, loss is 0.020185348
epoch time: 140476.520 ms, per step time: 1312.865 ms
```
## [Evaluation Process](#contents)
### Evaluation
- Evaluation on the LUNA16 dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet3d/Unet3d-10_110.ckpt".
```shell
python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
```
The above python command will run in the background. You can view the results through the file `eval.log`. The Dice score on the test dataset will be as follows:
```shell
# grep "eval average dice is:" eval.log
eval average dice is 0.9502010010453671
```
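The reported score is the standard Dice overlap, computed by `CalculateDice` in `src/utils.py`. A minimal sketch on toy one-hot foreground masks (the arrays below are illustrative only):

```python
import numpy as np

pred = np.array([1, 1, 0, 1, 0], dtype=np.float64)   # predicted foreground voxels, flattened
gt = np.array([1, 0, 0, 1, 1], dtype=np.float64)     # ground-truth foreground voxels, flattened

inter = np.dot(pred, gt)
union = np.dot(pred, pred) + np.dot(gt, gt)
dice = 2 * inter / (union + 1e-6)
print(dice)  # 2 * 2 / (3 + 3) ≈ 0.667
```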
## [Model Description](#contents)
### [Performance](#contents)
#### Evaluation Performance
| Parameters | Ascend |
| ------------------- | --------------------------------------------------------- |
| Model Version | Unet3D |
| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G            |
| Uploaded Date       | 03/18/2021 (month/day/year)                                 |
| MindSpore Version | 1.2.0 |
| Dataset | LUNA16 |
| Training Parameters | epoch = 10, batch_size = 1 |
| Optimizer | Adam |
| Loss Function | SoftmaxCrossEntropyWithLogits |
| Speed | 8pcs: 1795ms/step |
| Total time          | 8pcs: 0.62 hours                                            |
| Parameters (M) | 34 |
| Scripts | [unet3d script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet3d) |
#### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | Unet3D |
| Resource | Ascend 910 |
| Uploaded Date | 03/18/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | LUNA16 |
| batch_size | 1 |
| Dice | dice = 0.9502 |
| Model for inference | 56M (.ckpt file)            |
## [Description of Random Situation](#contents)
We set seed to 1 in train.py.
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np
from mindspore import dtype as mstype
from mindspore import Model, context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.unet3d_model import UNet3d
from src.config import config as cfg
from src.utils import create_sliding_window, CalculateDice
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
return parser.parse_args()
def test_net(data_dir, seg_dir, ckpt_path, config=None):
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
eval_data_size = eval_dataset.get_dataset_size()
print("train dataset length is:", eval_data_size)
network = UNet3d(config=config)
network.set_train(False)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)
model = Model(network)
index = 0
total_dice = 0
for batch in eval_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
image = batch["image"]
seg = batch["seg"]
print("current image shape is {}".format(image.shape), flush=True)
sliding_window_list, slice_list = create_sliding_window(image, config.roi_size, config.overlap)
image_size = (config.batch_size, config.num_classes) + image.shape[2:]
output_image = np.zeros(image_size, np.float32)
count_map = np.zeros(image_size, np.float32)
importance_map = np.ones(config.roi_size, np.float32)
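        # overlapping windows are summed into output_image, and count_map records how many
        # windows covered each voxel so the sum can be normalized back to an average prediction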
for window, slice_ in zip(sliding_window_list, slice_list):
window_image = Tensor(window, mstype.float32)
pred_probs = model.predict(window_image)
output_image[slice_] += pred_probs.asnumpy()
count_map[slice_] += importance_map
output_image = output_image / count_map
dice, _ = CalculateDice(output_image, seg)
print("The {} batch dice is {}".format(index, dice), flush=True)
total_dice += dice
index = index + 1
avg_dice = total_dice / eval_data_size
print("**********************End Eval***************************************")
print("eval average dice is {}".format(avg_dice))
if __name__ == '__main__':
args = get_args()
print("Testing setting:", args)
test_net(data_dir=args.data_url,
seg_dir=args.seg_url,
ckpt_path=args.ckpt_path,
config=cfg)


@@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -d $PATH2 ]
then
echo "error: IMAGE_PATH=$PATH2 is not a file"
exit 1
fi
PATH3=$(get_real_path $3)
echo $PATH3
if [ ! -d $PATH3 ]
then
echo "error: SEG_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--run_distribute=True \
--data_url=$PATH2 \
--seg_url=$PATH3 > log.txt 2>&1 &
cd ../
done


@@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_standalone_eval.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_FILE_PATH]"
    echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/segment/ /path/to/checkpoint/"
    echo "=============================================================================================================="
    exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
IMAGE_PATH=$(get_real_path $1)
SEG_PATH=$(get_real_path $2)
CHECKPOINT_FILE_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $SEG_PATH
echo $CHECKPOINT_FILE_PATH
if [ ! -d $IMAGE_PATH ]
then
echo "error: IMAGE_PATH=$IMAGE_PATH is not a path"
exit 1
fi
if [ ! -d $SEG_PATH ]
then
echo "error: SEG_PATH=$SEG_PATH is not a path"
exit 1
fi
if [ ! -f $CHECKPOINT_FILE_PATH ]
then
echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
cd ..


@@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -d $PATH1 ]
then
echo "error: IMAGE_PATH=$PATH1 is not a file"
exit 1
fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -d $PATH2 ]
then
echo "error: SEG_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 &
cd ..


@@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from easydict import EasyDict
config = EasyDict({
'model': 'Unet3d',
'lr': 0.0005,
'epoch_size': 10,
'batch_size': 1,
'warmup_step': 120,
'warmup_ratio': 0.3,
'num_classes': 4,
'in_channels': 1,
'keep_checkpoint_max': 5,
'loss_scale': 256.0,
'roi_size': [224, 224, 96],
'overlap': 0.25,
'min_val': -500,
'max_val': 1000,
'upper_limit': 5,
'lower_limit': 3,
})


@@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import nn_ops as nps
from mindspore.common.initializer import initializer
def weight_variable(shape):
init_value = initializer('Normal', shape, mstype.float32)
return Parameter(init_value)
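# The nps.Conv3D / nps.Conv3DTranspose primitives take the convolution weight as a call argument
# and carry no bias, so these thin wrapper Cells own the weight/bias Parameters and apply
# BiasAdd explicitly in construct.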
class Conv3D(nn.Cell):
def __init__(self,
in_channel,
out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
stride=1,
dilation=1,
group=1,
data_format="NCDHW",
bias_init="zeros",
has_bias=True):
super().__init__()
self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
self.weight = weight_variable(self.weight_shape)
self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
group=group, data_format=data_format)
self.bias_init = bias_init
self.has_bias = has_bias
self.bias_add = P.BiasAdd(data_format=data_format)
if self.has_bias:
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
def construct(self, x):
output = self.conv(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output
class Conv3DTranspose(nn.Cell):
def __init__(self,
in_channel,
out_channel,
kernel_size,
mode=1,
pad=0,
stride=1,
dilation=1,
group=1,
output_padding=0,
data_format="NCDHW",
bias_init="zeros",
has_bias=True):
super().__init__()
self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
self.weight = weight_variable(self.weight_shape)
self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
dilation=dilation, group=group, output_padding=output_padding, \
data_format=data_format)
self.bias_init = bias_init
self.has_bias = has_bias
self.bias_add = P.BiasAdd(data_format=data_format)
if self.has_bias:
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
def construct(self, x):
output = self.conv_transpose(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output


@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
from pathlib import Path
import SimpleITK as sitk
from src.config import config
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, help="Input image directory to be processed.")
parser.add_argument("--output_path", type=str, help="Output file path.")
args = parser.parse_args()
def get_list_of_files_in_dir(directory, file_types='*'):
"""
Get list of certain format files.
Args:
directory (str): The input directory for image.
file_types (str): The file_types to filter the files.
"""
return [f for f in Path(directory).glob(file_types) if f.is_file()]
def convert_nifti(input_dir, output_dir, roi_size, file_types):
"""
    Convert dataset into nifti format.
Args:
input_dir (str): The input directory for image.
output_dir (str): The output directory to save nifti format data.
roi_size (str): The size to crop the image.
file_types: File types to convert into nifti.
"""
file_list = get_list_of_files_in_dir(input_dir, file_types)
for file_name in file_list:
file_name = str(file_name)
input_file_name, _ = os.path.splitext(os.path.basename(file_name))
img = sitk.ReadImage(file_name)
image_array = sitk.GetArrayFromImage(img)
D, H, W = image_array.shape
if H < roi_size[0] or W < roi_size[1] or D < roi_size[2]:
print("file {} size is smaller than roi size, ignore it.".format(input_file_name))
continue
output_path = os.path.join(output_dir, input_file_name + ".nii.gz")
sitk.WriteImage(img, output_path)
print("create output file {} success.".format(output_path))
if __name__ == '__main__':
convert_nifti(args.input_path, args.output_path, config.roi_size, "*.mhd")


@@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
import os
import glob
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.transforms.py_transforms import Compose
from src.config import config as cfg
from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot
class ConvertLabel:
"""
    Map raw label values into the foreground classes of interest.
    Labels above `upper_limit` become background, and the remaining labels are shifted by
    `lower_limit - 1` and clipped into `[0, lower_limit]`, using the limits from `src/config.py`.
"""
def operation(self, data):
"""
        Apply the label mapping to `data`.
"""
data[data > cfg['upper_limit']] = 0
data = data - (cfg['lower_limit'] - 1)
data = np.clip(data, 0, cfg['lower_limit'])
return data
def __call__(self, image, label):
label = self.operation(label)
return image, label
def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True):
seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files]
train_ds = Dataset(data=train_files, seg=seg_files)
train_loader = ds.GeneratorDataset(train_ds, column_names=["image", "seg"], num_parallel_workers=4, \
shuffle=is_training, num_shards=rank_size, shard_id=rank_id)
if is_training:
transform_image = Compose([LoadData(),
ExpandChannel(),
Orientation(),
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
tgt_max=1.0, is_clip=True),
RandomCropSamples(roi_size=config.roi_size, num_samples=2),
ConvertLabel(),
OneHot(num_classes=config.num_classes)])
else:
transform_image = Compose([LoadData(),
ExpandChannel(),
Orientation(),
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
tgt_max=1.0, is_clip=True),
ConvertLabel()])
train_loader = train_loader.map(operations=transform_image, input_columns=["image", "seg"], num_parallel_workers=12,
python_multiprocessing=True)
if not is_training:
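        # evaluation volumes keep their native spatial size, so batch one scan at a time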
train_loader = train_loader.batch(1)
return train_loader


@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from src.config import config
class SoftmaxCrossEntropyWithLogits(_Loss):
def __init__(self):
super(SoftmaxCrossEntropyWithLogits, self).__init__()
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean()
def construct(self, logits, label):
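        # move the class channel last: (N, C, D, H, W) -> (N, D, H, W, C), then flatten all voxels
        # so SoftmaxCrossEntropyWithLogits sees a (num_voxels, num_classes) problem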
logits = self.transpose(logits, (0, 2, 3, 4, 1))
label = self.transpose(label, (0, 2, 3, 4, 1))
label = self.cast(label, mstype.float32)
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \
self.reshape(label, (-1, config['num_classes']))))
return self.get_loss(loss)


@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, base_step):
"""dynamic learning rate generator"""
base_lr = config.lr
total_steps = int(base_step * config.epoch_size)
warmup_steps = config.warmup_step
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr


@@ -0,0 +1,266 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
import re
import numpy as np
from mindspore import log as logger
import nibabel as nib
from src.utils import correct_nifti_head, get_random_patch
MAX_SEED = np.iinfo(np.uint32).max + 1
np_str_obj_array_pattern = re.compile(r'[SaUO]')
class Dataset:
"""
A generic dataset with a length property and an optional callable data transform
when fetching a data sample.
Args:
data: input data to load and transform to generate dataset for model.
seg: segment data to load and transform to generate dataset for model
"""
def __init__(self, data, seg):
self.data = data
self.seg = seg
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data = self.data[index]
seg = self.seg[index]
return [data], [seg]
class LoadData:
"""
Load Image data from provided files.
"""
def __init__(self, canonical=False, dtype=np.float32):
"""
Args:
canonical: if True, load the image as closest to canonical axis format.
dtype: convert the loaded image to this data type.
"""
self.canonical = canonical
self.dtype = dtype
def operation(self, filename):
"""
Args:
filename: path file or file-like object or a list of files.
"""
img_array = list()
compatible_meta = dict()
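        # the dataset pipeline hands the file path over as bytes, so decode it back to a str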
filename = str(filename, encoding="utf-8")
filename = [filename]
for name in filename:
img = nib.load(name)
img = correct_nifti_head(img)
header = dict(img.header)
header["filename_or_obj"] = name
header["affine"] = img.affine
header["original_affine"] = img.affine.copy()
header["canonical"] = self.canonical
ndim = img.header["dim"][0]
spatial_rank = min(ndim, 3)
header["spatial_shape"] = img.header["dim"][1 : spatial_rank + 1]
if self.canonical:
img = nib.as_closest_canonical(img)
header["affine"] = img.affine
img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
img.uncache()
if not compatible_meta:
for meta_key in header:
meta_datum = header[meta_key]
if isinstance(meta_datum, np.ndarray) \
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
continue
compatible_meta[meta_key] = meta_datum
else:
assert np.allclose(header["affine"], compatible_meta["affine"])
img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array
def __call__(self, filename1, filename2):
img_array = self.operation(filename1)
seg_array = self.operation(filename2)
return img_array, seg_array
class ExpandChannel:
"""
Expand a 1-length channel dimension to the input image.
"""
def operation(self, data):
"""
Args:
data(numpy.array): input data to expand channel.
"""
return data[None]
def __call__(self, img, label):
img_array = self.operation(img)
seg_array = self.operation(label)
return img_array, seg_array
class Orientation:
"""
Change the input image's orientation into the specified based on `ax`.
"""
def __init__(self, ax="RAS", labels=tuple(zip("LPI", "RAS"))):
"""
Args:
ax: N elements sequence for ND input's orientation.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
"""
self.ax = ax
self.labels = labels
def operation(self, data, affine=None):
"""
original orientation of `data` is defined by `affine`.
Args:
data: in shape (num_channels, H[, W, ...]).
affine (matrix): (N+1)x(N+1) original affine matrix for spatially ND `data`. Defaults to identity.
Returns:
data (reoriented in `self.ax`), original ax, current ax.
"""
if data.ndim <= 1:
raise ValueError("data must have at least one spatial dimension.")
if affine is None:
affine = np.eye(data.ndim, dtype=np.float64)
affine_copy = affine
else:
affine_copy = to_affine_nd(data.ndim-1, affine)
src = nib.io_orientation(affine_copy)
dst = nib.orientations.axcodes2ornt(self.ax[:data.ndim-1], labels=self.labels)
spatial_ornt = nib.orientations.ornt_transform(src, dst)
ornt = spatial_ornt.copy()
ornt[:, 0] += 1
ornt = np.concatenate([np.array([[0, 1]]), ornt])
data = nib.orientations.apply_orientation(data, ornt)
return data
def __call__(self, img, label):
img_array = self.operation(img)
seg_array = self.operation(label)
return img_array, seg_array
class ScaleIntensityRange:
"""
Apply specific intensity scaling to the whole numpy array.
Scaling from [src_min, src_max] to [tgt_min, tgt_max] with clip option.
Args:
src_min: intensity original range min.
src_max: intensity original range max.
tgt_min: intensity target range min.
tgt_max: intensity target range max.
is_clip: whether to clip after scaling.
"""
def __init__(self, src_min, src_max, tgt_min, tgt_max, is_clip=False):
self.src_min = src_min
self.src_max = src_max
self.tgt_min = tgt_min
self.tgt_max = tgt_max
self.is_clip = is_clip
def operation(self, data):
if self.src_max - self.src_min == 0.0:
logger.warning("Divide by zero (src_min == src_max)")
return data - self.src_min + self.tgt_min
data = (data - self.src_min) / (self.src_max - self.src_min)
data = data * (self.tgt_max - self.tgt_min) + self.tgt_min
if self.is_clip:
data = np.clip(data, self.tgt_min, self.tgt_max)
return data
def __call__(self, image, label):
image = self.operation(image)
return image, label
class RandomCropSamples:
"""
Random crop 3d image.
Args:
        roi_size: the spatial size of the crop region, e.g. [224, 224, 96].
        num_samples: the number of random crops to generate per volume.
"""
def __init__(self, roi_size, num_samples=1):
self.roi_size = roi_size
self.num_samples = num_samples
self.set_random_state(0)
def set_random_state(self, seed=None):
"""
Set the random seed to control the slice size.
Args:
seed: set the random state with an integer seed.
"""
if seed is not None:
_seed = seed % MAX_SEED
self.rand_fn = np.random.RandomState(_seed)
else:
self.rand_fn = np.random.RandomState()
return self
def get_random_slice(self, img_size):
slices = (slice(None),) + get_random_patch(img_size, self.roi_size, self.rand_fn)
return slices
def __call__(self, image, label):
res_image = []
res_label = []
for _ in range(self.num_samples):
slices = self.get_random_slice(image.shape[1:])
img = image[slices]
label_crop = label[slices]
res_image.append(img)
res_label.append(label_crop)
return np.array(res_image), np.array(res_label)
class OneHot:
def __init__(self, num_classes):
self.num_classes = num_classes
def one_hot(self, labels):
N, K = labels.shape
one_hot_encoding = np.zeros((N, self.num_classes, K), dtype=np.float32)
for i in range(N):
for j in range(K):
one_hot_encoding[i, labels[i][j], j] = 1
return one_hot_encoding
def operation(self, labels):
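        # flatten the spatial dims, one-hot encode every voxel label, then restore (N, C, D, H, W)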
N, _, D, H, W = labels.shape
labels = labels.astype(np.int32)
labels = np.reshape(labels, (N, -1))
labels = self.one_hot(labels)
labels = np.reshape(labels, (N, self.num_classes, D, H, W))
return labels
def __call__(self, image, label):
label = self.operation(label)
return image, label


@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from src.unet3d_parts import Down, Up
class UNet3d(nn.Cell):
def __init__(self, config=None):
super(UNet3d, self).__init__()
self.n_channels = config.in_channels
self.n_classes = config.num_classes
# down
self.transpose = P.Transpose()
self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float16).to_float(mstype.float16)
self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float16).to_float(mstype.float16)
self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float16).to_float(mstype.float16)
self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float16).to_float(mstype.float16)
self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
dtype=mstype.float16).to_float(mstype.float16)
# up
self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
dtype=mstype.float16).to_float(mstype.float16)
self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
dtype=mstype.float16).to_float(mstype.float16)
self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
dtype=mstype.float16).to_float(mstype.float16)
self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
dtype=mstype.float16, is_output=True).to_float(mstype.float16)
self.cast = P.Cast()
def construct(self, input_data):
input_data = self.cast(input_data, mstype.float16)
x1 = self.down1(input_data)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x5 = self.down5(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.cast(x, mstype.float32)
return x


@@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from src.conv import Conv3D, Conv3DTranspose
class BatchNorm3d(nn.Cell):
def __init__(self, num_features):
super().__init__()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.bn2d = nn.BatchNorm2d(num_features, data_format="NCHW")
def construct(self, x):
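        # emulate 3D batch norm with BatchNorm2d by folding depth into one spatial axis;
        # statistics remain per-channel over all D*H*W positions, so the result is equivalent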
x_shape = self.shape(x)
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(x)
bn3d_out = self.reshape(bn2d_out, x_shape)
return bn3d_out
class ResidualUnit(nn.Cell):
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), down=True, is_output=False):
super().__init__()
self.stride = stride
self.down = down
self.in_channel = in_channel
self.out_channel = out_channel
self.down_conv_1 = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
pad_mode="pad", stride=self.stride, pad=1)
self.is_output = is_output
if not is_output:
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
self.relu1 = nn.PReLU()
if self.down:
self.down_conv_2 = Conv3D(out_channel, out_channel, kernel_size=(3, 3, 3), \
pad_mode="pad", stride=1, pad=1)
self.relu2 = nn.PReLU()
if kernel_size[0] == 1:
self.residual = Conv3D(in_channel, out_channel, kernel_size=(1, 1, 1), \
pad_mode="valid", stride=self.stride)
else:
self.residual = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
pad_mode="pad", stride=self.stride, pad=1)
self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)
def construct(self, x):
out = self.down_conv_1(x)
if self.is_output:
return out
out = self.batchNormal1(out)
out = self.relu1(out)
if self.down:
out = self.down_conv_2(out)
out = self.batchNormal2(out)
out = self.relu2(out)
res = self.residual(x)
else:
res = x
return out + res
class Down(nn.Cell):
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), dtype=mstype.float16):
super().__init__()
self.stride = stride
self.in_channel = in_channel
self.out_channel = out_channel
self.down_conv = ResidualUnit(self.in_channel, self.out_channel, stride, kernel_size).to_float(dtype)
def construct(self, x):
x = self.down_conv(x)
return x
class Up(nn.Cell):
def __init__(self, in_channel, down_in_channel, out_channel, stride=2, is_output=False, dtype=mstype.float16):
super().__init__()
self.in_channel = in_channel
self.down_in_channel = down_in_channel
self.out_channel = out_channel
self.stride = stride
self.conv3d_transpose = Conv3DTranspose(in_channel=self.in_channel + self.down_in_channel, \
pad=1, out_channel=self.out_channel, kernel_size=(3, 3, 3), \
stride=self.stride, output_padding=(1, 1, 1))
self.concat = P.Concat(axis=1)
self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
is_output=is_output).to_float(dtype)
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
self.relu = nn.PReLU()
def construct(self, input_data, down_input):
x = self.concat((input_data, down_input))
x = self.conv3d_transpose(x)
x = self.batchNormal1(x)
x = self.relu(x)
x = self.conv(x)
return x


@@ -0,0 +1,170 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
from src.config import config
def correct_nifti_head(img):
"""
Check nifti object header's format, update the header if needed.
In the updated image pixdim matches the affine.
Args:
img: nifti image object
"""
dim = img.header["dim"][0]
if dim >= 5:
return img
pixdim = np.asarray(img.header.get_zooms())[:dim]
norm_affine = np.sqrt(np.sum(np.square(img.affine[:dim, :dim]), 0))
if np.allclose(pixdim, norm_affine):
return img
if hasattr(img, "get_sform"):
return rectify_header_sform_qform(img)
return img
def get_random_patch(dims, patch_size, rand_fn=None):
"""
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size`.
Args:
dims: shape of source array
patch_size: shape of patch size to generate
rand_fn: generate random numbers
Returns:
(tuple of slice): a tuple of slice objects defining the patch
"""
rand_int = np.random.randint if rand_fn is None else rand_fn.randint
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
def first(iterable, default=None):
"""
Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.
"""
for i in iterable:
return i
return default
def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
"""
Compute scan interval according to the image size, roi size and overlap.
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
use 1 instead to make sure sliding window works.
"""
if len(image_size) != num_image_dims:
raise ValueError("image different from spatial dims.")
if len(roi_size) != num_image_dims:
raise ValueError("roi size different from spatial dims.")
scan_interval = []
for i in range(num_image_dims):
if roi_size[i] == image_size[i]:
scan_interval.append(int(roi_size[i]))
else:
interval = int(roi_size[i] * (1 - overlap))
scan_interval.append(interval if interval > 0 else 1)
return tuple(scan_interval)
def dense_patch_slices(image_size, patch_size, scan_interval):
"""
Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.
Args:
image_size: dimensions of image to iterate over
patch_size: size of patches to generate slices
scan_interval: dense patch sampling interval
Returns:
a list of slice objects defining each patch
"""
num_spatial_dims = len(image_size)
patch_size = patch_size
scan_num = []
for i in range(num_spatial_dims):
if scan_interval[i] == 0:
scan_num.append(1)
else:
num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
starts = []
for dim in range(num_spatial_dims):
dim_starts = []
for idx in range(scan_num[dim]):
start_idx = idx * scan_interval[dim]
start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
dim_starts.append(start_idx)
starts.append(dim_starts)
out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
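    # prepend full slices for the batch and channel dims so each patch slice indexes the 5D array directly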
return [(slice(None),)*2 + tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
def create_sliding_window(image, roi_size, overlap):
num_image_dims = len(image.shape) - 2
if overlap < 0 or overlap >= 1:
raise AssertionError("overlap must be >= 0 and < 1.")
image_size_temp = list(image.shape[2:])
image_size = tuple(max(image_size_temp[i], roi_size[i]) for i in range(num_image_dims))
scan_interval = _get_scan_interval(image_size, roi_size, num_image_dims, overlap)
slices = dense_patch_slices(image_size, roi_size, scan_interval)
    windows_sliding = [image[s] for s in slices]
return windows_sliding, slices
def one_hot(labels):
N, _, D, H, W = labels.shape
labels = np.reshape(labels, (N, -1))
labels = labels.astype(np.int32)
N, K = labels.shape
one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
for i in range(N):
for j in range(K):
one_hot_encoding[i, labels[i][j], j] = 1
labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
return labels
def CalculateDice(y_pred, label):
"""
Args:
y_pred: predictions. As for classification tasks,
            `y_pred` should have the shape [BN] where N is larger than 1. As for segmentation tasks,
the shape should be [BNHW] or [BNHWD].
label: ground truth, the first dim is batch.
"""
y_pred_output = np.expand_dims(np.argmax(y_pred, axis=1), axis=1)
y_pred = one_hot(y_pred_output)
y = one_hot(label)
y_pred, y = ignore_background(y_pred, y)
inter = np.dot(y_pred.flatten(), y.flatten()).astype(np.float64)
union = np.dot(y_pred.flatten(), y_pred.flatten()).astype(np.float64) + np.dot(y.flatten(), \
y.flatten()).astype(np.float64)
single_dice_coeff = 2 * inter / (union + 1e-6)
return single_dice_coeff, y_pred_output
def ignore_background(y_pred, label):
"""
This function is used to remove background (the first channel) for `y_pred` and `y`.
Args:
y_pred: predictions. As for classification tasks,
            `y_pred` should have the shape [BN] where N is larger than 1. As for segmentation tasks,
the shape should be [BNHW] or [BNHWD].
label: ground truth, the first dim is batch.
"""
label = label[:, 1:] if label.shape[1] > 1 else label
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred
return y_pred, label


@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import ast
import mindspore
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Model, context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from src.dataset import create_dataset
from src.unet3d_model import UNet3d
from src.config import config as cfg
from src.lr_schedule import dynamic_lr
from src.loss import SoftmaxCrossEntropyWithLogits
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
device_id=device_id)
mindspore.set_seed(1)
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks')
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \
help='Run distribute, default: false')
return parser.parse_args()
def train_net(data_dir,
seg_dir,
run_distribute,
config=None):
if run_distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=rank_size,
gradients_mean=True)
else:
rank_id = 0
rank_size = 1
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \
rank_size=rank_size, rank_id=rank_id, is_training=True)
train_data_size = train_dataset.get_dataset_size()
print("train dataset length is:", train_data_size)
network = UNet3d(config=config)
loss = SoftmaxCrossEntropyWithLogits()
lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
optimizer = nn.Adam(params=network.trainable_params(), learning_rate=lr)
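    # constant loss scale for the float16 network; drop_overflow_update=False keeps applying
    # the parameter update on every step even if an overflow is detected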
scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
network.set_train()
model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager)
time_cb = TimeMonitor(data_size=train_data_size)
loss_cb = LossMonitor()
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model),
directory='./ckpt_{}/'.format(device_id),
config=ckpt_config)
callbacks_list = [loss_cb, time_cb, ckpoint_cb]
print("============== Starting Training ==============")
model.train(config.epoch_size, train_dataset, callbacks=callbacks_list)
print("============== End Training ==============")
if __name__ == '__main__':
args = get_args()
print("Training setting:", args)
train_net(data_dir=args.data_url,
seg_dir=args.seg_url,
run_distribute=args.run_distribute,
config=cfg)