!16069 modify model_zoo unet3d for cloud

From: @Somnus2020
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
mindspore-ci-bot 2021-05-10 20:06:19 +08:00 committed by Gitee
commit 7c0ced0a87
16 changed files with 456 additions and 131 deletions

View File

@ -82,6 +82,37 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
```
If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); you can then start training and evaluation as follows:
```python
# run distributed training on modelarts example
# (1) First, perform a or b.
# a. Set "enable_modelarts=True" on yaml file.
# Set other parameters on yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add other parameters on the website UI interface.
# (2) Add nibabel to pip-requirements.txt and put the file in the code directory.
# (3) Set the config path to "config_path=/The path of config in S3/"
# (4) Set the code directory to "/path/unet" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
# run evaluation on modelarts example
# (1) Copy or upload your trained model to S3 bucket.
# (2) Perform a or b.
# a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
# b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
# (3) Add nibabel to pip-requirements.txt and put the file in the code directory.
# (4) Set the config path to "config_path=/The path of config in S3/"
# (5) Set the code directory to "/path/unet" on the website UI interface.
# (6) Set the startup file to "eval.py" on the website UI interface.
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (8) Create your job.
```
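For steps (2)/(3) above, pip-requirements.txt is the standard ModelArts mechanism for installing extra Python packages before the job starts; for this model it presumably needs only the single line `nibabel`.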
## [Script Description](#contents)
### [Script and Sample Code](#contents)
@ -96,7 +127,6 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
├── src
│ ├──config.py // parameter configuration
│ ├──dataset.py // creating dataset
│ ├──lr_schedule.py // learning rate scheduler
│ ├──transform.py // handle dataset
@ -105,6 +135,12 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
│ ├──utils.py // General components (callback function)
│ ├──unet3d_model.py // Unet3D model
│ ├──unet3d_parts.py // Unet3D part
├── model_utils
│ ├──config.py // parameter configuration
│ ├──device_adapter.py // device adapter
│ ├──local_adapter.py // local adapter
│ ├──moxing_adapter.py // moxing adapter
├── default_config.yaml // parameter configuration
├── train.py // training script
├── eval.py // evaluation script

View File

@ -0,0 +1,57 @@
# Built-in Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
run_distribute: False
enable_profiling: False
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'Unet3d-9-877.ckpt'
# ==============================================================================
# Training options
lr: 0.0005
batch_size: 1
epoch_size: 10
warmup_step: 120
warmup_ratio: 0.3
num_classes: 4
in_channels: 1
keep_checkpoint_max: 1
loss_scale: 256.0
roi_size: [224, 224, 96]
overlap: 0.25
min_val: -500
max_val: 1000
upper_limit: 5
lower_limit: 3
# Export options
device_id: 0
ckpt_file: ""
file_name: ""
file_format: ""
---
# Help description for each configuration
enable_modelarts: 'Whether to train on ModelArts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'Local path that checkpoint_url is downloaded to'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether to enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."
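Every scalar key in the first yaml document above is also exposed as a command-line flag of the same name and type once src/model_utils/config.py (below) parses it; typical overrides look like this (a sketch, with placeholder paths):
```python
# any top-level scalar key in default_config.yaml can be overridden:
# python train.py --lr=0.001 --epoch_size=20 --enable_modelarts=True
# python eval.py --data_path=/path/to/data --checkpoint_file_path=/path/to/Unet3d-9-877.ckpt
```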

View File

@ -14,32 +14,28 @@
# ============================================================================
import os
import argparse
import numpy as np
from mindspore import dtype as mstype
from mindspore import Model, context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.unet3d_model import UNet3d
from src.config import config as cfg
from src.utils import create_sliding_window, CalculateDice
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
return parser.parse_args()
def test_net(data_dir, seg_dir, ckpt_path, config=None):
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
@moxing_wrapper()
def test_net(data_path, ckpt_path):
data_dir = data_path + "/image/"
seg_dir = data_path + "/seg/"
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, is_training=False)
eval_data_size = eval_dataset.get_dataset_size()
print("train dataset length is:", eval_data_size)
network = UNet3d(config=config)
network = UNet3d()
network.set_train(False)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)
@ -70,9 +66,5 @@ def test_net(data_dir, seg_dir, ckpt_path, config=None):
print("eval average dice is {}".format(avg_dice))
if __name__ == '__main__':
args = get_args()
print("Testing setting:", args)
test_net(data_dir=args.data_url,
seg_dir=args.seg_url,
ckpt_path=args.ckpt_path,
config=cfg)
test_net(data_path=config.data_path,
ckpt_path=config.checkpoint_file_path)
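The single --data_path flag replaces the old --data_url/--seg_url pair; as the two lines at the top of test_net show, it is expected to contain image/ and seg/ subdirectories (a sketch, with a placeholder path):
```python
import os

# Expected layout under data_path, per the data_dir/seg_dir lines above;
# segmentation files must share file names with their images (see create_dataset).
data_path = "/path/to/data"                   # placeholder
image_dir = os.path.join(data_path, "image")  # *.nii.gz volumes
seg_dir = os.path.join(data_path, "seg")      # matching *.nii.gz labels
print(image_dir, seg_dir)
```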

View File

@ -14,7 +14,7 @@
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
exit 1
@ -45,14 +45,6 @@ then
exit 1
fi
PATH3=$(get_real_path $3)
echo $PATH3
if [ ! -d $PATH3 ]
then
echo "error: SEG_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
@ -65,6 +57,7 @@ do
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp ../*.yaml ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
@ -73,8 +66,8 @@ do
python train.py \
--run_distribute=True \
--data_url=$PATH2 \
--seg_url=$PATH3 > log.txt 2>&1 &
--data_path=$PATH2 \
--output_path './output' > log.txt 2>&1 &
cd ../
done

View File

@ -14,7 +14,7 @@
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
if [ $# != 2 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
@ -23,7 +23,7 @@ then
echo "=============================================================================================================="
fi
if [ $# != 3 ]
if [ $# != 2 ]
then
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
exit 1
@ -36,22 +36,14 @@ get_real_path(){
echo "$(realpath -m $PWD/$1)"
fi
}
IMAGE_PATH=$(get_real_path $1)
SEG_PATH=$(get_real_path $2)
CHECKPOINT_FILE_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $SEG_PATH
PATH1=$(get_real_path $1)
CHECKPOINT_FILE_PATH=$(get_real_path $2)
echo $PATH1
echo $CHECKPOINT_FILE_PATH
if [ ! -d $IMAGE_PATH ]
if [ ! -d $PATH1 ]
then
echo "error: IMAGE_PATH=$IMAGE_PATH is not a path"
exit 1
fi
if [ ! -d $SEG_PATH ]
then
echo "error: SEG_PATH=$SEG_PATH is not a path"
echo "error: PATH1=$PATH1 is not a path"
exit 1
fi
@ -74,9 +66,10 @@ fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
cd ..

View File

@ -14,7 +14,7 @@
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
if [ $# -ne 1 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
exit 1
@ -36,14 +36,6 @@ then
exit 1
fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -d $PATH2 ]
then
echo "error: SEG_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
@ -54,9 +46,10 @@ rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp ../*.yaml ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 &
python train.py --data_path=$PATH1 --output_path './output' > train.log 2>&1 &
cd ..

View File

@ -17,7 +17,7 @@ import os
import argparse
from pathlib import Path
import SimpleITK as sitk
from src.config import config
from src.model_utils.config import config
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, help="Input image directory to be processed.")

View File

@ -18,7 +18,7 @@ import glob
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.transforms.py_transforms import Compose
from src.config import config as cfg
from src.model_utils.config import config
from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot
class ConvertLabel:
@ -34,16 +34,16 @@ class ConvertLabel:
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
"""
data[data > cfg['upper_limit']] = 0
data = data - (cfg['lower_limit'] - 1)
data = np.clip(data, 0, cfg['lower_limit'])
data[data > config.upper_limit] = 0
data = data - (config.lower_limit - 1)
data = np.clip(data, 0, config.lower_limit)
return data
def __call__(self, image, label):
label = self.operation(label)
return image, label
def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True):
def create_dataset(data_path, seg_path, rank_size=1, rank_id=0, is_training=True):
seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files]
train_ds = Dataset(data=train_files, seg=seg_files)
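A toy walk-through of ConvertLabel.operation with the defaults upper_limit=5 and lower_limit=3 from default_config.yaml:
```python
import numpy as np

# Labels above upper_limit become background, the rest are shifted so
# lower_limit maps to 1, then everything is clipped into [0, lower_limit].
data = np.array([0, 1, 2, 3, 4, 5, 6, 7])
data[data > 5] = 0          # upper_limit
data = data - (3 - 1)       # lower_limit - 1
data = np.clip(data, 0, 3)  # lower_limit
print(data)                 # [0 0 0 1 2 3 0 0]
```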

View File

@ -17,7 +17,7 @@ import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from src.config import config
from src.model_utils.config import config
class SoftmaxCrossEntropyWithLogits(_Loss):
def __init__(self):
@ -27,11 +27,12 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean()
self.num_classes = config.num_classes
def construct(self, logits, label):
logits = self.transpose(logits, (0, 2, 3, 4, 1))
label = self.transpose(label, (0, 2, 3, 4, 1))
label = self.cast(label, mstype.float32)
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \
self.reshape(label, (-1, config['num_classes']))))
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, self.num_classes)), \
self.reshape(label, (-1, self.num_classes))))
return self.get_loss(loss)
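To make the transpose/reshape concrete, here is the shape flow for the default num_classes=4 and a roi_size crop of 224x224x96 (a NumPy sketch):
```python
import numpy as np

# NCDHW logits -> channel-last -> one row of 4 class scores per voxel.
logits = np.zeros((1, 4, 224, 224, 96), dtype=np.float32)
moved = np.transpose(logits, (0, 2, 3, 4, 1))  # (1, 224, 224, 96, 4)
flat = moved.reshape(-1, 4)
print(flat.shape)  # (4816896, 4): 224 * 224 * 96 voxels
```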

View File

@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
_config_path = "./default_config.yaml"
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please refer to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
else:
raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
print(cfg_helper)
except yaml.YAMLError as exc:
raise ValueError("Failed to parse yaml: {}".format(exc))
return cfg, cfg_helper
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
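Once imported, config behaves like a plain namespace built from the yaml plus any command-line overrides; a minimal sketch (values assume no overrides):
```python
from src.model_utils.config import config

# Every yaml key becomes an attribute; nested dicts become nested Config objects.
print(config.batch_size)  # 1
print(config.roi_size)    # [224, 224, 96]
print(config)             # pretty-printed via Config.__str__ (pformat)
```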

View File

@ -1,34 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from easydict import EasyDict
config = EasyDict({
'model': 'Unet3d',
'lr': 0.0005,
'epoch_size': 10,
'batch_size': 1,
'warmup_step': 120,
'warmup_ratio': 0.3,
'num_classes': 4,
'in_channels': 1,
'keep_checkpoint_max': 5,
'loss_scale': 256.0,
'roi_size': [224, 224, 96],
'overlap': 0.25,
'min_val': -500,
'max_val': 1000,
'upper_limit': 5,
'lower_limit': 3,
})
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from src.model_utils.config import config
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
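Because the import above is resolved once at module load based on enable_modelarts, callers are insulated from the environment; a usage sketch:
```python
from src.model_utils.device_adapter import get_device_id, get_device_num

# The same call sites work locally (env vars) and on ModelArts (moxing).
print(get_device_id(), get_device_num())  # e.g. "0 1" for a single local device
```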

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory when the source is a remote URL and the destination is a local path;
upload from the local directory to remote OBS when the order is reversed.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
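A minimal sketch of how the entry points use this decorator (hypothetical main; the real train.py/eval.py work the same way):
```python
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

@moxing_wrapper()
def main():
    # With enable_modelarts=True, data_url has already been synced to
    # config.data_path by the time this body runs; after it returns,
    # config.output_path is uploaded back to train_url.
    print("data lives in", config.data_path)

if __name__ == '__main__':
    main()
```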

View File

@ -17,9 +17,10 @@ import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from src.unet3d_parts import Down, Up
from src.model_utils.config import config
class UNet3d(nn.Cell):
def __init__(self, config=None):
def __init__(self):
super(UNet3d, self).__init__()
self.n_channels = config.in_channels
self.n_classes = config.num_classes

View File

@ -15,7 +15,7 @@
import math
import numpy as np
from src.config import config
from src.model_utils.config import config
def correct_nifti_head(img):
"""
@ -131,11 +131,11 @@ def one_hot(labels):
labels = np.reshape(labels, (N, -1))
labels = labels.astype(np.int32)
N, K = labels.shape
one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
one_hot_encoding = np.zeros((N, config.num_classes, K), dtype=np.float32)
for i in range(N):
for j in range(K):
one_hot_encoding[i, labels[i][j], j] = 1
labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
labels = np.reshape(one_hot_encoding, (N, config.num_classes, D, H, W))
return labels
def CalculateDice(y_pred, label):
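A toy illustration of the encoding layout built by one_hot above, assuming num_classes=4:
```python
import numpy as np

# Three voxels with class ids 0, 2, 1 -> a (num_classes, K) plane with a
# single 1 per column, matching the inner loops of one_hot for one sample.
labels = np.array([0, 2, 1])
encoding = np.zeros((4, labels.size), dtype=np.float32)
encoding[labels, np.arange(labels.size)] = 1
print(encoding)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [0. 0. 0.]]
```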

View File

@ -14,43 +14,36 @@
# ============================================================================
import os
import argparse
import ast
import mindspore
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Model, context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.communication.management import init
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from src.dataset import create_dataset
from src.unet3d_model import UNet3d
from src.config import config as cfg
from src.lr_schedule import dynamic_lr
from src.loss import SoftmaxCrossEntropyWithLogits
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
device_id=device_id)
mindspore.set_seed(1)
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks')
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \
help='Run distribute, default: false')
return parser.parse_args()
def train_net(data_dir,
seg_dir,
run_distribute,
config=None):
@moxing_wrapper()
def train_net(data_path,
run_distribute):
data_dir = data_path + "/image/"
seg_dir = data_path + "/seg/"
if run_distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
rank_id = get_device_id()
rank_size = get_device_num()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=rank_size,
@ -58,12 +51,12 @@ def train_net(data_dir,
else:
rank_id = 0
rank_size = 1
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, \
rank_size=rank_size, rank_id=rank_id, is_training=True)
train_data_size = train_dataset.get_dataset_size()
print("train dataset length is:", train_data_size)
network = UNet3d(config=config)
network = UNet3d()
loss = SoftmaxCrossEntropyWithLogits()
lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
@ -77,8 +70,9 @@ def train_net(data_dir,
loss_cb = LossMonitor()
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model),
directory='./ckpt_{}/'.format(device_id),
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
ckpoint_cb = ModelCheckpoint(prefix='Unet3d',
directory=os.path.join(ckpt_save_dir, 'ckpt_{}'.format(device_id)),
config=ckpt_config)
callbacks_list = [loss_cb, time_cb, ckpoint_cb]
print("============== Starting Training ==============")
@ -86,9 +80,5 @@ def train_net(data_dir,
print("============== End Training ==============")
if __name__ == '__main__':
args = get_args()
print("Training setting:", args)
train_net(data_dir=args.data_url,
seg_dir=args.seg_url,
run_distribute=args.run_distribute,
config=cfg)
train_net(data_path=config.data_path,
run_distribute=config.run_distribute)
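A sketch of where train_net saves checkpoints under the defaults above (output_path="/cache/train", checkpoint_path="./checkpoint/", device 0); the stray "./" segment contributed by checkpoint_path is harmless to the filesystem:
```python
import os

# ckpt_save_dir as composed in train_net, with default_config.yaml values
ckpt_save_dir = os.path.join("/cache/train", "./checkpoint/")
print(os.path.join(ckpt_save_dir, "ckpt_0"))  # /cache/train/./checkpoint/ckpt_0
```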