forked from mindspore-Ecosystem/mindspore
!13761 add unet3D
From: @Somnus2020 Reviewed-by: @linqingke,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
f8198ece2e
|
@ -0,0 +1,263 @@
|
|||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
- [Unet Description](#unet-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [running on Ascend](#running-on-ascend)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Inference Performance](#inference-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
## [Unet Description](#contents)
|
||||
|
||||
Unet3D model is widely used for 3D medical image segmentation. The construct of Unet3D network is similar to the Unet, the main difference is that Unet3D use 3D operations like Conv3D while Unet is anentirely 2D architecture. To know more information about Unet3D network, you can read the original paper Unet3D: Learning Dense VolumetricSegmentation from Sparse Annotation.
|
||||
|
||||
## [Model Architecture](#contents)
|
||||
|
||||
Unet3D model is created based on the previous Unet(2D), which includes an encoder part and a decoder part. The encoder part is used to analyze the whole picture and extract and analyze features, while the decoder part is to generate a segmented block image. In this model, we also add residual block in the base block to improve the network.
|
||||
|
||||
## [Dataset](#contents)
|
||||
|
||||
Dataset used: [LUNA16](https://luna16.grand-challenge.org/)
|
||||
|
||||
- Description: The data is to automatically detect the location of nodules from volumetric CT images. 888 CT scans from LIDC-IDRI database are provided. The complete dataset is divided into 10 subsets that should be used for the 10-fold cross-validation. All subsets are available as compressed zip files.
|
||||
|
||||
- Dataset size:888
|
||||
- Train:878 images
|
||||
- Test:10 images
|
||||
- Data format:zip
|
||||
- Note:Data will be processed in convert_nifti.py
|
||||
|
||||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
## [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
- Select the network and dataset to use
|
||||
|
||||
```shell
|
||||
|
||||
Convert dataset into mifti format.
|
||||
python ./src/convert_nifti.py --input_path=/path/to/input_image/ --output_path=/path/to/output_image/
|
||||
|
||||
```
|
||||
|
||||
Refer to `src/config.py`. We support some parameter configurations for quick start.
|
||||
|
||||
- Run on Ascend
|
||||
|
||||
```python
|
||||
|
||||
# run training example
|
||||
python train.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ > train.log 2>&1 &
|
||||
|
||||
# run distributed training example
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]
|
||||
|
||||
# run evaluation example
|
||||
python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
|
||||
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
||||
```text
|
||||
|
||||
.
|
||||
└─unet3d
|
||||
├── README.md // descriptions about Unet3D
|
||||
├── scripts
|
||||
│ ├──run_disribute_train.sh // shell script for distributed on Ascend
|
||||
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
|
||||
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
|
||||
├── src
|
||||
│ ├──config.py // parameter configuration
|
||||
│ ├──dataset.py // creating dataset
|
||||
│ ├──lr_schedule.py // learning rate scheduler
|
||||
│ ├──transform.py // handle dataset
|
||||
│ ├──convert_nifti.py // convert dataset
|
||||
│ ├──loss.py // loss
|
||||
│ ├──conv.py // conv components
|
||||
│ ├──utils.py // General components (callback function)
|
||||
│ ├──unet3d_model.py // Unet3D model
|
||||
│ ├──unet3d_parts.py // Unet3D part
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
|
||||
```
|
||||
|
||||
### [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
|
||||
- config for Unet3d, luna16 dataset
|
||||
|
||||
```python
|
||||
|
||||
'model': 'Unet3d', # model name
|
||||
'lr': 0.0005, # learning rate
|
||||
'epochs': 10, # total training epochs when run 1p
|
||||
'batchsize': 1, # training batch size
|
||||
"warmup_step": 120, # warmp up step in lr generator
|
||||
"warmup_ratio": 0.3, # warpm up ratio
|
||||
'num_classes': 4, # the number of classes in the dataset
|
||||
'in_channels': 1, # the number of channels
|
||||
'keep_checkpoint_max': 5, # only keep the last keep_checkpoint_max checkpoint
|
||||
'loss_scale': 256.0, # loss scale
|
||||
'roi_size': [224, 224, 96], # random roi size
|
||||
'overlap': 0.25, # overlap rate
|
||||
'min_val': -500, # intersity original range min
|
||||
'max_val': 1000, # intersity original range max
|
||||
'upper_limit': 5 # upper limit of num_classes
|
||||
'lower_limit': 3 # lower limit of num_classes
|
||||
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Training
|
||||
|
||||
#### running on Ascend
|
||||
|
||||
```shell
|
||||
|
||||
python train.py --data_url=/path/to/data/ -seg_url=/path/to/segment/ > train.log 2>&1 &
|
||||
|
||||
```
|
||||
|
||||
The python command above will run in the background, you can view the results through the file `train.log`.
|
||||
|
||||
After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
|
||||
|
||||
```shell
|
||||
|
||||
epoch: 1 step: 878, loss is 0.55011123
|
||||
epoch time: 1443410.353 ms, per step time: 1688.199 ms
|
||||
epoch: 2 step: 878, loss is 0.58278626
|
||||
epoch time: 1172136.839 ms, per step time: 1370.920 ms
|
||||
epoch: 3 step: 878, loss is 0.43625978
|
||||
epoch time: 1135890.834 ms, per step time: 1328.537 ms
|
||||
epoch: 4 step: 878, loss is 0.06556784
|
||||
epoch time: 1180467.795 ms, per step time: 1380.664 ms
|
||||
|
||||
```
|
||||
|
||||
#### Distributed Training
|
||||
|
||||
> Notes:
|
||||
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
|
||||
>
|
||||
|
||||
```shell
|
||||
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]
|
||||
|
||||
```
|
||||
|
||||
The above shell script will run distribute training in the background. You can view the results through the file `/train_parallel[X]/log.txt`. The loss value will be achieved as follows:
|
||||
|
||||
```shell
|
||||
|
||||
epoch: 1 step: 110, loss is 0.8294426
|
||||
epoch time: 468891.643 ms, per step time: 4382.165 ms
|
||||
epoch: 2 step: 110, loss is 0.58278626
|
||||
epoch time: 165469.201 ms, per step time: 1546.441 ms
|
||||
epoch: 3 step: 110, loss is 0.43625978
|
||||
epoch time: 158915.771 ms, per step time: 1485.194 ms
|
||||
...
|
||||
epoch: 9 step: 110, loss is 0.016280059
|
||||
epoch time: 172815.179 ms, per step time: 1615.095 ms
|
||||
epoch: 10 step: 110, loss is 0.020185348
|
||||
epoch time: 140476.520 ms, per step time: 1312.865 ms
|
||||
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Evaluation
|
||||
|
||||
- evaluation on dataset when running on Ascend
|
||||
|
||||
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet3d/Unet3d-10_110.ckpt".
|
||||
|
||||
```shell
|
||||
|
||||
python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
|
||||
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||
|
||||
```shell
|
||||
|
||||
# grep "eval average dice is:" eval.log
|
||||
eval average dice is 0.9502010010453671
|
||||
|
||||
```
|
||||
|
||||
## [Model Description](#contents)
|
||||
|
||||
### [Performance](#contents)
|
||||
|
||||
#### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------------------------------------- |
|
||||
| Model Version | Unet3D |
|
||||
| Resource | Ascend 910; CPU 2.60GHz,192cores;Memory,755G |
|
||||
| uploaded Date | 03/18/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | LUNA16 |
|
||||
| Training Parameters | epoch = 10, batch_size = 1 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | SoftmaxCrossEntropyWithLogits |
|
||||
| Speed | 8pcs: 1795ms/step |
|
||||
| Total time | 8pcs: 0.62hours |
|
||||
| Parameters (M) | 34 |
|
||||
| Scripts | [unet3d script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet3d) |
|
||||
|
||||
#### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | Unet3D |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 03/18/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | LUNA16 |
|
||||
| batch_size | 1 |
|
||||
| Dice | dice = 0.9502 |
|
||||
| Model for inference | 56M(.ckpt file) |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We set seed to 1 in train.py.
|
||||
|
||||
## [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Model, context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.dataset import create_dataset
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
from src.utils import create_sliding_window, CalculateDice
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
|
||||
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
|
||||
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
|
||||
parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
|
||||
return parser.parse_args()
|
||||
|
||||
def test_net(data_dir, seg_dir, ckpt_path, config=None):
|
||||
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
|
||||
eval_data_size = eval_dataset.get_dataset_size()
|
||||
print("train dataset length is:", eval_data_size)
|
||||
|
||||
network = UNet3d(config=config)
|
||||
network.set_train(False)
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
model = Model(network)
|
||||
index = 0
|
||||
total_dice = 0
|
||||
for batch in eval_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = batch["image"]
|
||||
seg = batch["seg"]
|
||||
print("current image shape is {}".format(image.shape), flush=True)
|
||||
sliding_window_list, slice_list = create_sliding_window(image, config.roi_size, config.overlap)
|
||||
image_size = (config.batch_size, config.num_classes) + image.shape[2:]
|
||||
output_image = np.zeros(image_size, np.float32)
|
||||
count_map = np.zeros(image_size, np.float32)
|
||||
importance_map = np.ones(config.roi_size, np.float32)
|
||||
for window, slice_ in zip(sliding_window_list, slice_list):
|
||||
window_image = Tensor(window, mstype.float32)
|
||||
pred_probs = model.predict(window_image)
|
||||
output_image[slice_] += pred_probs.asnumpy()
|
||||
count_map[slice_] += importance_map
|
||||
output_image = output_image / count_map
|
||||
dice, _ = CalculateDice(output_image, seg)
|
||||
print("The {} batch dice is {}".format(index, dice), flush=True)
|
||||
total_dice += dice
|
||||
index = index + 1
|
||||
avg_dice = total_dice / eval_data_size
|
||||
print("**********************End Eval***************************************")
|
||||
print("eval average dice is {}".format(avg_dice))
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print("Testing setting:", args)
|
||||
test_net(data_dir=args.data_url,
|
||||
seg_dir=args.seg_url,
|
||||
ckpt_path=args.ckpt_path,
|
||||
config=cfg)
|
|
@ -0,0 +1,80 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH3=$(get_real_path $3)
|
||||
echo $PATH3
|
||||
if [ ! -d $PATH3 ]
|
||||
then
|
||||
echo "error: SEG_PATH=$PATH3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
python train.py \
|
||||
--run_distribute=True \
|
||||
--data_url=$PATH2 \
|
||||
--seg_url=$PATH3 > log.txt 2>&1 &
|
||||
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,82 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
|
||||
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
|
||||
echo "=============================================================================================================="
|
||||
fi
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
IMAGE_PATH=$(get_real_path $1)
|
||||
SEG_PATH=$(get_real_path $2)
|
||||
CHECKPOINT_FILE_PATH=$(get_real_path $3)
|
||||
echo $IMAGE_PATH
|
||||
echo $SEG_PATH
|
||||
echo $CHECKPOINT_FILE_PATH
|
||||
|
||||
if [ ! -d $IMAGE_PATH ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$IMAGE_PATH is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $SEG_PATH ]
|
||||
then
|
||||
echo "error: SEG_PATH=$SEG_PATH is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $CHECKPOINT_FILE_PATH ]
|
||||
then
|
||||
echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval
|
||||
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
|
||||
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
cd ..
|
|
@ -0,0 +1,62 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: SEG_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from easydict import EasyDict
|
||||
config = EasyDict({
|
||||
'model': 'Unet3d',
|
||||
'lr': 0.0005,
|
||||
'epoch_size': 10,
|
||||
'batch_size': 1,
|
||||
'warmup_step': 120,
|
||||
'warmup_ratio': 0.3,
|
||||
'num_classes': 4,
|
||||
'in_channels': 1,
|
||||
'keep_checkpoint_max': 5,
|
||||
'loss_scale': 256.0,
|
||||
'roi_size': [224, 224, 96],
|
||||
'overlap': 0.25,
|
||||
'min_val': -500,
|
||||
'max_val': 1000,
|
||||
'upper_limit': 5,
|
||||
'lower_limit': 3,
|
||||
})
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
def weight_variable(shape):
|
||||
init_value = initializer('Normal', shape, mstype.float32)
|
||||
return Parameter(init_value)
|
||||
|
||||
class Conv3D(nn.Cell):
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad_mode="valid",
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1,
|
||||
data_format="NCDHW",
|
||||
bias_init="zeros",
|
||||
has_bias=True):
|
||||
super().__init__()
|
||||
self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
||||
self.weight = weight_variable(self.weight_shape)
|
||||
self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
|
||||
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
|
||||
group=group, data_format=data_format)
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
||||
if self.has_bias:
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
class Conv3DTranspose(nn.Cell):
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1,
|
||||
output_padding=0,
|
||||
data_format="NCDHW",
|
||||
bias_init="zeros",
|
||||
has_bias=True):
|
||||
super().__init__()
|
||||
self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
||||
self.weight = weight_variable(self.weight_shape)
|
||||
self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
|
||||
kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
|
||||
dilation=dilation, group=group, output_padding=output_padding, \
|
||||
data_format=data_format)
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
||||
if self.has_bias:
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv_transpose(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import SimpleITK as sitk
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_path", type=str, help="Input image directory to be processed.")
|
||||
parser.add_argument("--output_path", type=str, help="Output file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
def get_list_of_files_in_dir(directory, file_types='*'):
|
||||
"""
|
||||
Get list of certain format files.
|
||||
|
||||
Args:
|
||||
directory (str): The input directory for image.
|
||||
file_types (str): The file_types to filter the files.
|
||||
"""
|
||||
return [f for f in Path(directory).glob(file_types) if f.is_file()]
|
||||
|
||||
def convert_nifti(input_dir, output_dir, roi_size, file_types):
|
||||
"""
|
||||
Convert dataset into mifti format.
|
||||
|
||||
Args:
|
||||
input_dir (str): The input directory for image.
|
||||
output_dir (str): The output directory to save nifti format data.
|
||||
roi_size (str): The size to crop the image.
|
||||
file_types: File types to convert into nifti.
|
||||
"""
|
||||
file_list = get_list_of_files_in_dir(input_dir, file_types)
|
||||
for file_name in file_list:
|
||||
file_name = str(file_name)
|
||||
input_file_name, _ = os.path.splitext(os.path.basename(file_name))
|
||||
img = sitk.ReadImage(file_name)
|
||||
image_array = sitk.GetArrayFromImage(img)
|
||||
D, H, W = image_array.shape
|
||||
if H < roi_size[0] or W < roi_size[1] or D < roi_size[2]:
|
||||
print("file {} size is smaller than roi size, ignore it.".format(input_file_name))
|
||||
continue
|
||||
output_path = os.path.join(output_dir, input_file_name + ".nii.gz")
|
||||
sitk.WriteImage(img, output_path)
|
||||
print("create output file {} success.".format(output_path))
|
||||
|
||||
if __name__ == '__main__':
|
||||
convert_nifti(args.input_path, args.output_path, config.roi_size, "*.mhd")
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =========================================================================
|
||||
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.transforms.py_transforms import Compose
|
||||
from src.config import config as cfg
|
||||
from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot
|
||||
|
||||
class ConvertLabel:
|
||||
"""
|
||||
Crop at the center of image with specified ROI size.
|
||||
|
||||
Args:
|
||||
roi_size: the spatial size of the crop region e.g. [224,224,128]
|
||||
If its components have non-positive values, the corresponding size of input image will be used.
|
||||
"""
|
||||
def operation(self, data):
|
||||
"""
|
||||
Apply the transform to `img`, assuming `img` is channel-first and
|
||||
slicing doesn't apply to the channel dim.
|
||||
"""
|
||||
data[data > cfg['upper_limit']] = 0
|
||||
data = data - (cfg['lower_limit'] - 1)
|
||||
data = np.clip(data, 0, cfg['lower_limit'])
|
||||
return data
|
||||
|
||||
def __call__(self, image, label):
|
||||
label = self.operation(label)
|
||||
return image, label
|
||||
|
||||
def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True):
|
||||
seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
|
||||
train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files]
|
||||
train_ds = Dataset(data=train_files, seg=seg_files)
|
||||
train_loader = ds.GeneratorDataset(train_ds, column_names=["image", "seg"], num_parallel_workers=4, \
|
||||
shuffle=is_training, num_shards=rank_size, shard_id=rank_id)
|
||||
|
||||
if is_training:
|
||||
transform_image = Compose([LoadData(),
|
||||
ExpandChannel(),
|
||||
Orientation(),
|
||||
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
|
||||
tgt_max=1.0, is_clip=True),
|
||||
RandomCropSamples(roi_size=config.roi_size, num_samples=2),
|
||||
ConvertLabel(),
|
||||
OneHot(num_classes=config.num_classes)])
|
||||
else:
|
||||
transform_image = Compose([LoadData(),
|
||||
ExpandChannel(),
|
||||
Orientation(),
|
||||
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
|
||||
tgt_max=1.0, is_clip=True),
|
||||
ConvertLabel()])
|
||||
|
||||
train_loader = train_loader.map(operations=transform_image, input_columns=["image", "seg"], num_parallel_workers=12,
|
||||
python_multiprocessing=True)
|
||||
if not is_training:
|
||||
train_loader = train_loader.batch(1)
|
||||
return train_loader
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from src.config import config
|
||||
|
||||
class SoftmaxCrossEntropyWithLogits(_Loss):
|
||||
def __init__(self):
|
||||
super(SoftmaxCrossEntropyWithLogits, self).__init__()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
||||
self.cast = P.Cast()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
|
||||
def construct(self, logits, label):
|
||||
logits = self.transpose(logits, (0, 2, 3, 4, 1))
|
||||
label = self.transpose(label, (0, 2, 3, 4, 1))
|
||||
label = self.cast(label, mstype.float32)
|
||||
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \
|
||||
self.reshape(label, (-1, config['num_classes']))))
|
||||
return self.get_loss(loss)
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import math
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
|
||||
base = float(current_step - warmup_steps) / float(decay_steps)
|
||||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
|
||||
return learning_rate
|
||||
|
||||
def dynamic_lr(config, base_step):
|
||||
"""dynamic learning rate generator"""
|
||||
base_lr = config.lr
|
||||
total_steps = int(base_step * config.epoch_size)
|
||||
warmup_steps = config.warmup_step
|
||||
lr = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
|
||||
else:
|
||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
|
||||
return lr
|
|
@ -0,0 +1,266 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =========================================================================
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
import nibabel as nib
|
||||
from src.utils import correct_nifti_head, get_random_patch
|
||||
|
||||
MAX_SEED = np.iinfo(np.uint32).max + 1
|
||||
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
||||
|
||||
class Dataset:
|
||||
"""
|
||||
A generic dataset with a length property and an optional callable data transform
|
||||
when fetching a data sample.
|
||||
|
||||
Args:
|
||||
data: input data to load and transform to generate dataset for model.
|
||||
seg: segment data to load and transform to generate dataset for model
|
||||
"""
|
||||
def __init__(self, data, seg):
|
||||
self.data = data
|
||||
self.seg = seg
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = self.data[index]
|
||||
seg = self.seg[index]
|
||||
return [data], [seg]
|
||||
|
||||
class LoadData:
|
||||
"""
|
||||
Load Image data from provided files.
|
||||
"""
|
||||
def __init__(self, canonical=False, dtype=np.float32):
|
||||
"""
|
||||
Args:
|
||||
canonical: if True, load the image as closest to canonical axis format.
|
||||
dtype: convert the loaded image to this data type.
|
||||
"""
|
||||
self.canonical = canonical
|
||||
self.dtype = dtype
|
||||
|
||||
def operation(self, filename):
|
||||
"""
|
||||
Args:
|
||||
filename: path file or file-like object or a list of files.
|
||||
"""
|
||||
img_array = list()
|
||||
compatible_meta = dict()
|
||||
filename = str(filename, encoding="utf-8")
|
||||
filename = [filename]
|
||||
for name in filename:
|
||||
img = nib.load(name)
|
||||
img = correct_nifti_head(img)
|
||||
header = dict(img.header)
|
||||
header["filename_or_obj"] = name
|
||||
header["affine"] = img.affine
|
||||
header["original_affine"] = img.affine.copy()
|
||||
header["canonical"] = self.canonical
|
||||
ndim = img.header["dim"][0]
|
||||
spatial_rank = min(ndim, 3)
|
||||
header["spatial_shape"] = img.header["dim"][1 : spatial_rank + 1]
|
||||
if self.canonical:
|
||||
img = nib.as_closest_canonical(img)
|
||||
header["affine"] = img.affine
|
||||
img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
|
||||
img.uncache()
|
||||
if not compatible_meta:
|
||||
for meta_key in header:
|
||||
meta_datum = header[meta_key]
|
||||
if isinstance(meta_datum, np.ndarray) \
|
||||
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
|
||||
continue
|
||||
compatible_meta[meta_key] = meta_datum
|
||||
else:
|
||||
assert np.allclose(header["affine"], compatible_meta["affine"])
|
||||
|
||||
img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
|
||||
return img_array
|
||||
|
||||
def __call__(self, filename1, filename2):
|
||||
img_array = self.operation(filename1)
|
||||
seg_array = self.operation(filename2)
|
||||
return img_array, seg_array
|
||||
|
||||
|
||||
class ExpandChannel:
|
||||
"""
|
||||
Expand a 1-length channel dimension to the input image.
|
||||
"""
|
||||
def operation(self, data):
|
||||
"""
|
||||
Args:
|
||||
data(numpy.array): input data to expand channel.
|
||||
"""
|
||||
return data[None]
|
||||
|
||||
def __call__(self, img, label):
|
||||
img_array = self.operation(img)
|
||||
seg_array = self.operation(label)
|
||||
return img_array, seg_array
|
||||
|
||||
|
||||
class Orientation:
|
||||
"""
|
||||
Change the input image's orientation into the specified based on `ax`.
|
||||
"""
|
||||
def __init__(self, ax="RAS", labels=tuple(zip("LPI", "RAS"))):
|
||||
"""
|
||||
Args:
|
||||
ax: N elements sequence for ND input's orientation.
|
||||
labels: optional, None or sequence of (2,) sequences
|
||||
(2,) sequences are labels for (beginning, end) of output axis.
|
||||
"""
|
||||
self.ax = ax
|
||||
self.labels = labels
|
||||
|
||||
def operation(self, data, affine=None):
|
||||
"""
|
||||
original orientation of `data` is defined by `affine`.
|
||||
|
||||
Args:
|
||||
data: in shape (num_channels, H[, W, ...]).
|
||||
affine (matrix): (N+1)x(N+1) original affine matrix for spatially ND `data`. Defaults to identity.
|
||||
|
||||
Returns:
|
||||
data (reoriented in `self.ax`), original ax, current ax.
|
||||
"""
|
||||
if data.ndim <= 1:
|
||||
raise ValueError("data must have at least one spatial dimension.")
|
||||
if affine is None:
|
||||
affine = np.eye(data.ndim, dtype=np.float64)
|
||||
affine_copy = affine
|
||||
else:
|
||||
affine_copy = to_affine_nd(data.ndim-1, affine)
|
||||
src = nib.io_orientation(affine_copy)
|
||||
dst = nib.orientations.axcodes2ornt(self.ax[:data.ndim-1], labels=self.labels)
|
||||
spatial_ornt = nib.orientations.ornt_transform(src, dst)
|
||||
ornt = spatial_ornt.copy()
|
||||
ornt[:, 0] += 1
|
||||
ornt = np.concatenate([np.array([[0, 1]]), ornt])
|
||||
data = nib.orientations.apply_orientation(data, ornt)
|
||||
return data
|
||||
|
||||
def __call__(self, img, label):
|
||||
img_array = self.operation(img)
|
||||
seg_array = self.operation(label)
|
||||
return img_array, seg_array
|
||||
|
||||
|
||||
class ScaleIntensityRange:
|
||||
"""
|
||||
Apply specific intensity scaling to the whole numpy array.
|
||||
Scaling from [src_min, src_max] to [tgt_min, tgt_max] with clip option.
|
||||
|
||||
Args:
|
||||
src_min: intensity original range min.
|
||||
src_max: intensity original range max.
|
||||
tgt_min: intensity target range min.
|
||||
tgt_max: intensity target range max.
|
||||
is_clip: whether to clip after scaling.
|
||||
"""
|
||||
def __init__(self, src_min, src_max, tgt_min, tgt_max, is_clip=False):
|
||||
self.src_min = src_min
|
||||
self.src_max = src_max
|
||||
self.tgt_min = tgt_min
|
||||
self.tgt_max = tgt_max
|
||||
self.is_clip = is_clip
|
||||
|
||||
def operation(self, data):
|
||||
if self.src_max - self.src_min == 0.0:
|
||||
logger.warning("Divide by zero (src_min == src_max)")
|
||||
return data - self.src_min + self.tgt_min
|
||||
data = (data - self.src_min) / (self.src_max - self.src_min)
|
||||
data = data * (self.tgt_max - self.tgt_min) + self.tgt_min
|
||||
if self.is_clip:
|
||||
data = np.clip(data, self.tgt_min, self.tgt_max)
|
||||
return data
|
||||
|
||||
def __call__(self, image, label):
|
||||
image = self.operation(image)
|
||||
return image, label
|
||||
|
||||
|
||||
class RandomCropSamples:
|
||||
"""
|
||||
Random crop 3d image.
|
||||
|
||||
Args:
|
||||
keys: keys of the corresponding items to be transformed.
|
||||
roi_size: if `random_size` is True, it specifies the minimum crop region.
|
||||
num_samples: the amount of crop images.
|
||||
"""
|
||||
def __init__(self, roi_size, num_samples=1):
|
||||
self.roi_size = roi_size
|
||||
self.num_samples = num_samples
|
||||
self.set_random_state(0)
|
||||
|
||||
def set_random_state(self, seed=None):
|
||||
"""
|
||||
Set the random seed to control the slice size.
|
||||
|
||||
Args:
|
||||
seed: set the random state with an integer seed.
|
||||
"""
|
||||
if seed is not None:
|
||||
_seed = seed % MAX_SEED
|
||||
self.rand_fn = np.random.RandomState(_seed)
|
||||
else:
|
||||
self.rand_fn = np.random.RandomState()
|
||||
return self
|
||||
|
||||
def get_random_slice(self, img_size):
|
||||
slices = (slice(None),) + get_random_patch(img_size, self.roi_size, self.rand_fn)
|
||||
return slices
|
||||
|
||||
def __call__(self, image, label):
|
||||
res_image = []
|
||||
res_label = []
|
||||
for _ in range(self.num_samples):
|
||||
slices = self.get_random_slice(image.shape[1:])
|
||||
img = image[slices]
|
||||
label_crop = label[slices]
|
||||
res_image.append(img)
|
||||
res_label.append(label_crop)
|
||||
return np.array(res_image), np.array(res_label)
|
||||
|
||||
class OneHot:
|
||||
def __init__(self, num_classes):
|
||||
self.num_classes = num_classes
|
||||
|
||||
def one_hot(self, labels):
|
||||
N, K = labels.shape
|
||||
one_hot_encoding = np.zeros((N, self.num_classes, K), dtype=np.float32)
|
||||
for i in range(N):
|
||||
for j in range(K):
|
||||
one_hot_encoding[i, labels[i][j], j] = 1
|
||||
return one_hot_encoding
|
||||
|
||||
def operation(self, labels):
|
||||
N, _, D, H, W = labels.shape
|
||||
labels = labels.astype(np.int32)
|
||||
labels = np.reshape(labels, (N, -1))
|
||||
labels = self.one_hot(labels)
|
||||
labels = np.reshape(labels, (N, self.num_classes, D, H, W))
|
||||
return labels
|
||||
|
||||
def __call__(self, image, label):
|
||||
label = self.operation(label)
|
||||
return image, label
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.unet3d_parts import Down, Up
|
||||
|
||||
class UNet3d(nn.Cell):
|
||||
def __init__(self, config=None):
|
||||
super(UNet3d, self).__init__()
|
||||
self.n_channels = config.in_channels
|
||||
self.n_classes = config.num_classes
|
||||
|
||||
# down
|
||||
self.transpose = P.Transpose()
|
||||
self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
|
||||
# up
|
||||
self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
|
||||
dtype=mstype.float16, is_output=True).to_float(mstype.float16)
|
||||
|
||||
self.cast = P.Cast()
|
||||
|
||||
|
||||
def construct(self, input_data):
|
||||
input_data = self.cast(input_data, mstype.float16)
|
||||
x1 = self.down1(input_data)
|
||||
x2 = self.down2(x1)
|
||||
x3 = self.down3(x2)
|
||||
x4 = self.down4(x3)
|
||||
x5 = self.down5(x4)
|
||||
|
||||
x = self.up1(x5, x4)
|
||||
x = self.up2(x, x3)
|
||||
x = self.up3(x, x2)
|
||||
x = self.up4(x, x1)
|
||||
x = self.cast(x, mstype.float32)
|
||||
return x
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.conv import Conv3D, Conv3DTranspose
|
||||
|
||||
class BatchNorm3d(nn.Cell):
|
||||
def __init__(self, num_features):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.bn2d = nn.BatchNorm2d(num_features, data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape(x)
|
||||
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
|
||||
bn2d_out = self.bn2d(x)
|
||||
bn3d_out = self.reshape(bn2d_out, x_shape)
|
||||
return bn3d_out
|
||||
|
||||
class ResidualUnit(nn.Cell):
|
||||
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), down=True, is_output=False):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.down = down
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.down_conv_1 = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||
pad_mode="pad", stride=self.stride, pad=1)
|
||||
self.is_output = is_output
|
||||
if not is_output:
|
||||
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
|
||||
self.relu1 = nn.PReLU()
|
||||
if self.down:
|
||||
self.down_conv_2 = Conv3D(out_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||
pad_mode="pad", stride=1, pad=1)
|
||||
self.relu2 = nn.PReLU()
|
||||
if kernel_size[0] == 1:
|
||||
self.residual = Conv3D(in_channel, out_channel, kernel_size=(1, 1, 1), \
|
||||
pad_mode="valid", stride=self.stride)
|
||||
else:
|
||||
self.residual = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||
pad_mode="pad", stride=self.stride, pad=1)
|
||||
self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
out = self.down_conv_1(x)
|
||||
if self.is_output:
|
||||
return out
|
||||
out = self.batchNormal1(out)
|
||||
out = self.relu1(out)
|
||||
if self.down:
|
||||
out = self.down_conv_2(out)
|
||||
out = self.batchNormal2(out)
|
||||
out = self.relu2(out)
|
||||
res = self.residual(x)
|
||||
else:
|
||||
res = x
|
||||
return out + res
|
||||
|
||||
class Down(nn.Cell):
|
||||
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), dtype=mstype.float16):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.down_conv = ResidualUnit(self.in_channel, self.out_channel, stride, kernel_size).to_float(dtype)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.down_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Up(nn.Cell):
|
||||
def __init__(self, in_channel, down_in_channel, out_channel, stride=2, is_output=False, dtype=mstype.float16):
|
||||
super().__init__()
|
||||
self.in_channel = in_channel
|
||||
self.down_in_channel = down_in_channel
|
||||
self.out_channel = out_channel
|
||||
self.stride = stride
|
||||
self.conv3d_transpose = Conv3DTranspose(in_channel=self.in_channel + self.down_in_channel, \
|
||||
pad=1, out_channel=self.out_channel, kernel_size=(3, 3, 3), \
|
||||
stride=self.stride, output_padding=(1, 1, 1))
|
||||
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
|
||||
is_output=is_output).to_float(dtype)
|
||||
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
|
||||
self.relu = nn.PReLU()
|
||||
|
||||
def construct(self, input_data, down_input):
|
||||
x = self.concat((input_data, down_input))
|
||||
x = self.conv3d_transpose(x)
|
||||
x = self.batchNormal1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv(x)
|
||||
return x
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
|
||||
def correct_nifti_head(img):
|
||||
"""
|
||||
Check nifti object header's format, update the header if needed.
|
||||
In the updated image pixdim matches the affine.
|
||||
|
||||
Args:
|
||||
img: nifti image object
|
||||
"""
|
||||
dim = img.header["dim"][0]
|
||||
if dim >= 5:
|
||||
return img
|
||||
pixdim = np.asarray(img.header.get_zooms())[:dim]
|
||||
norm_affine = np.sqrt(np.sum(np.square(img.affine[:dim, :dim]), 0))
|
||||
if np.allclose(pixdim, norm_affine):
|
||||
return img
|
||||
if hasattr(img, "get_sform"):
|
||||
return rectify_header_sform_qform(img)
|
||||
return img
|
||||
|
||||
def get_random_patch(dims, patch_size, rand_fn=None):
|
||||
"""
|
||||
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size`.
|
||||
|
||||
Args:
|
||||
dims: shape of source array
|
||||
patch_size: shape of patch size to generate
|
||||
rand_fn: generate random numbers
|
||||
|
||||
Returns:
|
||||
(tuple of slice): a tuple of slice objects defining the patch
|
||||
"""
|
||||
rand_int = np.random.randint if rand_fn is None else rand_fn.randint
|
||||
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
|
||||
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
|
||||
|
||||
|
||||
def first(iterable, default=None):
|
||||
"""
|
||||
Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.
|
||||
"""
|
||||
for i in iterable:
|
||||
return i
|
||||
return default
|
||||
|
||||
def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
|
||||
"""
|
||||
Compute scan interval according to the image size, roi size and overlap.
|
||||
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
|
||||
use 1 instead to make sure sliding window works.
|
||||
"""
|
||||
if len(image_size) != num_image_dims:
|
||||
raise ValueError("image different from spatial dims.")
|
||||
if len(roi_size) != num_image_dims:
|
||||
raise ValueError("roi size different from spatial dims.")
|
||||
|
||||
scan_interval = []
|
||||
for i in range(num_image_dims):
|
||||
if roi_size[i] == image_size[i]:
|
||||
scan_interval.append(int(roi_size[i]))
|
||||
else:
|
||||
interval = int(roi_size[i] * (1 - overlap))
|
||||
scan_interval.append(interval if interval > 0 else 1)
|
||||
return tuple(scan_interval)
|
||||
|
||||
def dense_patch_slices(image_size, patch_size, scan_interval):
|
||||
"""
|
||||
Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.
|
||||
|
||||
Args:
|
||||
image_size: dimensions of image to iterate over
|
||||
patch_size: size of patches to generate slices
|
||||
scan_interval: dense patch sampling interval
|
||||
|
||||
Returns:
|
||||
a list of slice objects defining each patch
|
||||
"""
|
||||
num_spatial_dims = len(image_size)
|
||||
patch_size = patch_size
|
||||
scan_num = []
|
||||
for i in range(num_spatial_dims):
|
||||
if scan_interval[i] == 0:
|
||||
scan_num.append(1)
|
||||
else:
|
||||
num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
|
||||
scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
|
||||
scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
|
||||
starts = []
|
||||
for dim in range(num_spatial_dims):
|
||||
dim_starts = []
|
||||
for idx in range(scan_num[dim]):
|
||||
start_idx = idx * scan_interval[dim]
|
||||
start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
|
||||
dim_starts.append(start_idx)
|
||||
starts.append(dim_starts)
|
||||
out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
|
||||
return [(slice(None),)*2 + tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
|
||||
|
||||
def create_sliding_window(image, roi_size, overlap):
|
||||
num_image_dims = len(image.shape) - 2
|
||||
if overlap < 0 or overlap >= 1:
|
||||
raise AssertionError("overlap must be >= 0 and < 1.")
|
||||
image_size_temp = list(image.shape[2:])
|
||||
image_size = tuple(max(image_size_temp[i], roi_size[i]) for i in range(num_image_dims))
|
||||
|
||||
scan_interval = _get_scan_interval(image_size, roi_size, num_image_dims, overlap)
|
||||
slices = dense_patch_slices(image_size, roi_size, scan_interval)
|
||||
windows_sliding = [image[slice] for slice in slices]
|
||||
return windows_sliding, slices
|
||||
|
||||
def one_hot(labels):
|
||||
N, _, D, H, W = labels.shape
|
||||
labels = np.reshape(labels, (N, -1))
|
||||
labels = labels.astype(np.int32)
|
||||
N, K = labels.shape
|
||||
one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
|
||||
for i in range(N):
|
||||
for j in range(K):
|
||||
one_hot_encoding[i, labels[i][j], j] = 1
|
||||
labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
|
||||
return labels
|
||||
|
||||
def CalculateDice(y_pred, label):
|
||||
"""
|
||||
Args:
|
||||
y_pred: predictions. As for classification tasks,
|
||||
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
|
||||
the shape should be [BNHW] or [BNHWD].
|
||||
label: ground truth, the first dim is batch.
|
||||
"""
|
||||
y_pred_output = np.expand_dims(np.argmax(y_pred, axis=1), axis=1)
|
||||
y_pred = one_hot(y_pred_output)
|
||||
y = one_hot(label)
|
||||
y_pred, y = ignore_background(y_pred, y)
|
||||
inter = np.dot(y_pred.flatten(), y.flatten()).astype(np.float64)
|
||||
union = np.dot(y_pred.flatten(), y_pred.flatten()).astype(np.float64) + np.dot(y.flatten(), \
|
||||
y.flatten()).astype(np.float64)
|
||||
single_dice_coeff = 2 * inter / (union + 1e-6)
|
||||
return single_dice_coeff, y_pred_output
|
||||
|
||||
def ignore_background(y_pred, label):
|
||||
"""
|
||||
This function is used to remove background (the first channel) for `y_pred` and `y`.
|
||||
Args:
|
||||
y_pred: predictions. As for classification tasks,
|
||||
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
|
||||
the shape should be [BNHW] or [BNHWD].
|
||||
label: ground truth, the first dim is batch.
|
||||
"""
|
||||
label = label[:, 1:] if label.shape[1] > 1 else label
|
||||
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred
|
||||
return y_pred, label
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, Model, context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from src.dataset import create_dataset
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.loss import SoftmaxCrossEntropyWithLogits
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
|
||||
device_id=device_id)
|
||||
mindspore.set_seed(1)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks')
|
||||
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
|
||||
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
|
||||
parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \
|
||||
help='Run distribute, default: false')
|
||||
return parser.parse_args()
|
||||
|
||||
def train_net(data_dir,
|
||||
seg_dir,
|
||||
run_distribute,
|
||||
config=None):
|
||||
if run_distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode,
|
||||
device_num=rank_size,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
rank_id = 0
|
||||
rank_size = 1
|
||||
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \
|
||||
rank_size=rank_size, rank_id=rank_id, is_training=True)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
print("train dataset length is:", train_data_size)
|
||||
|
||||
network = UNet3d(config=config)
|
||||
|
||||
loss = SoftmaxCrossEntropyWithLogits()
|
||||
lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
|
||||
optimizer = nn.Adam(params=network.trainable_params(), learning_rate=lr)
|
||||
scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
network.set_train()
|
||||
|
||||
model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager)
|
||||
|
||||
time_cb = TimeMonitor(data_size=train_data_size)
|
||||
loss_cb = LossMonitor()
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model),
|
||||
directory='./ckpt_{}/'.format(device_id),
|
||||
config=ckpt_config)
|
||||
callbacks_list = [loss_cb, time_cb, ckpoint_cb]
|
||||
print("============== Starting Training ==============")
|
||||
model.train(config.epoch_size, train_dataset, callbacks=callbacks_list)
|
||||
print("============== End Training ==============")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print("Training setting:", args)
|
||||
train_net(data_dir=args.data_url,
|
||||
seg_dir=args.seg_url,
|
||||
run_distribute=args.run_distribute,
|
||||
config=cfg)
|
Loading…
Reference in New Issue