forked from mindspore-Ecosystem/mindspore
!16069 modify model_zoo unet3d for cloud
From: @Somnus2020 Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34
This commit is contained in:
commit
7c0ced0a87
|
@ -82,6 +82,37 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
|
|||
|
||||
```
|
||||
|
||||
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
|
||||
|
||||
```python
|
||||
# run distributed training on modelarts example
|
||||
# (1) First, Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on yaml file.
|
||||
# Set other parameters on yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Download nibabel and set pip-requirements.txt to code directory
|
||||
# (3) Set the config directory to "config_path=/The path of config in S3/"
|
||||
# (4) Set the code directory to "/path/unet" on the website UI interface.
|
||||
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (7) Create your job.
|
||||
|
||||
# run evaluation on modelarts example
|
||||
# (1) Copy or upload your trained model to S3 bucket.
|
||||
# (2) Perform a or b.
|
||||
# a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
|
||||
# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
|
||||
# b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
|
||||
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||
# (3) Download nibabel and set pip-requirements.txt to code directory
|
||||
# (4) Set the config directory to "config_path=/The path of config in S3/"
|
||||
# (5) Set the code directory to "/path/unet" on the website UI interface.
|
||||
# (6) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
@ -96,7 +127,6 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
|
|||
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
|
||||
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
|
||||
├── src
|
||||
│ ├──config.py // parameter configuration
|
||||
│ ├──dataset.py // creating dataset
|
||||
│ ├──lr_schedule.py // learning rate scheduler
|
||||
│ ├──transform.py // handle dataset
|
||||
|
@ -105,6 +135,12 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
|
|||
│ ├──utils.py // General components (callback function)
|
||||
│ ├──unet3d_model.py // Unet3D model
|
||||
│ ├──unet3d_parts.py // Unet3D part
|
||||
├── model_utils
|
||||
│ ├──config.py // parameter configuration
|
||||
│ ├──device_adapter.py // device adapter
|
||||
│ ├──local_adapter.py // local adapter
|
||||
│ ├──moxing_adapter.py // moxing adapter
|
||||
├── default_config.yaml // parameter configuration
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
run_distribute: False
|
||||
enable_profiling: False
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path/"
|
||||
device_target: 'Ascend'
|
||||
checkpoint_path: './checkpoint/'
|
||||
checkpoint_file_path: 'Unet3d-9-877.ckpt'
|
||||
|
||||
# ==============================================================================
|
||||
# Training options
|
||||
lr: 0.0005
|
||||
batch_size: 1
|
||||
epoch_size: 10
|
||||
warmup_step: 120
|
||||
warmup_ratio: 0.3
|
||||
num_classes: 4
|
||||
in_channels: 1
|
||||
keep_checkpoint_max: 1
|
||||
loss_scale: 256.0
|
||||
roi_size : [224, 224, 96]
|
||||
overlap: 0.25
|
||||
min_val: -500
|
||||
max_val: 1000
|
||||
upper_limit: 5
|
||||
lower_limit: 3
|
||||
|
||||
# Export options
|
||||
device_id: 0
|
||||
ckpt_file: ""
|
||||
file_name: ""
|
||||
file_format: ""
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
checkpoint_url: 'The location of checkpoint for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
load_path: 'The location of checkpoint for obs'
|
||||
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
num_classes: 'Class for dataset'
|
||||
batch_size: "Batch size for training and evaluation"
|
||||
epoch_size: "Total training epochs."
|
||||
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
checkpoint_file_path: "The location of the checkpoint file."
|
|
@ -14,32 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Model, context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.dataset import create_dataset
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
from src.utils import create_sliding_window, CalculateDice
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
|
||||
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
|
||||
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
|
||||
parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
|
||||
return parser.parse_args()
|
||||
|
||||
def test_net(data_dir, seg_dir, ckpt_path, config=None):
|
||||
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
|
||||
@moxing_wrapper()
|
||||
def test_net(data_path, ckpt_path):
|
||||
data_dir = data_path + "/image/"
|
||||
seg_dir = data_path + "/seg/"
|
||||
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, is_training=False)
|
||||
eval_data_size = eval_dataset.get_dataset_size()
|
||||
print("train dataset length is:", eval_data_size)
|
||||
|
||||
network = UNet3d(config=config)
|
||||
network = UNet3d()
|
||||
network.set_train(False)
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
@ -70,9 +66,5 @@ def test_net(data_dir, seg_dir, ckpt_path, config=None):
|
|||
print("eval average dice is {}".format(avg_dice))
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print("Testing setting:", args)
|
||||
test_net(data_dir=args.data_url,
|
||||
seg_dir=args.seg_url,
|
||||
ckpt_path=args.ckpt_path,
|
||||
config=cfg)
|
||||
test_net(data_path=config.data_path,
|
||||
ckpt_path=config.checkpoint_file_path)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 3 ]
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
|
||||
exit 1
|
||||
|
@ -45,14 +45,6 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
PATH3=$(get_real_path $3)
|
||||
echo $PATH3
|
||||
if [ ! -d $PATH3 ]
|
||||
then
|
||||
echo "error: SEG_PATH=$PATH3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
@ -65,6 +57,7 @@ do
|
|||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp ../*.yaml ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
|
@ -73,8 +66,8 @@ do
|
|||
|
||||
python train.py \
|
||||
--run_distribute=True \
|
||||
--data_url=$PATH2 \
|
||||
--seg_url=$PATH3 > log.txt 2>&1 &
|
||||
--data_path=$PATH2 \
|
||||
--output_path './output' > log.txt 2>&1 &
|
||||
|
||||
cd ../
|
||||
done
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
|
@ -23,7 +23,7 @@ then
|
|||
echo "=============================================================================================================="
|
||||
fi
|
||||
|
||||
if [ $# != 3 ]
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
|
@ -36,22 +36,14 @@ get_real_path(){
|
|||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
IMAGE_PATH=$(get_real_path $1)
|
||||
SEG_PATH=$(get_real_path $2)
|
||||
CHECKPOINT_FILE_PATH=$(get_real_path $3)
|
||||
echo $IMAGE_PATH
|
||||
echo $SEG_PATH
|
||||
PATH1=$(get_real_path $1)
|
||||
CHECKPOINT_FILE_PATH=$(get_real_path $2)
|
||||
echo $PATH1
|
||||
echo $CHECKPOINT_FILE_PATH
|
||||
|
||||
if [ ! -d $IMAGE_PATH ]
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$IMAGE_PATH is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $SEG_PATH ]
|
||||
then
|
||||
echo "error: SEG_PATH=$SEG_PATH is not a path"
|
||||
echo "error: PATH1=$PATH1 is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -74,9 +66,10 @@ fi
|
|||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
|
||||
python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
|
||||
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
cd ..
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
|
||||
exit 1
|
||||
|
@ -36,14 +36,6 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: SEG_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
|
@ -54,9 +46,10 @@ rm -rf ./train
|
|||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp ../*.yaml ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 &
|
||||
python train.py --data_path=$PATH1 --output_path './output' > train.log 2>&1 &
|
||||
cd ..
|
||||
|
|
|
@ -17,7 +17,7 @@ import os
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
import SimpleITK as sitk
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_path", type=str, help="Input image directory to be processed.")
|
||||
|
|
|
@ -18,7 +18,7 @@ import glob
|
|||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.transforms.py_transforms import Compose
|
||||
from src.config import config as cfg
|
||||
from src.model_utils.config import config
|
||||
from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot
|
||||
|
||||
class ConvertLabel:
|
||||
|
@ -34,16 +34,16 @@ class ConvertLabel:
|
|||
Apply the transform to `img`, assuming `img` is channel-first and
|
||||
slicing doesn't apply to the channel dim.
|
||||
"""
|
||||
data[data > cfg['upper_limit']] = 0
|
||||
data = data - (cfg['lower_limit'] - 1)
|
||||
data = np.clip(data, 0, cfg['lower_limit'])
|
||||
data[data > config.upper_limit] = 0
|
||||
data = data - (config.lower_limit - 1)
|
||||
data = np.clip(data, 0, config.lower_limit)
|
||||
return data
|
||||
|
||||
def __call__(self, image, label):
|
||||
label = self.operation(label)
|
||||
return image, label
|
||||
|
||||
def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True):
|
||||
def create_dataset(data_path, seg_path, rank_size=1, rank_id=0, is_training=True):
|
||||
seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
|
||||
train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files]
|
||||
train_ds = Dataset(data=train_files, seg=seg_files)
|
||||
|
|
|
@ -17,7 +17,7 @@ import mindspore.nn as nn
|
|||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
|
||||
class SoftmaxCrossEntropyWithLogits(_Loss):
|
||||
def __init__(self):
|
||||
|
@ -27,11 +27,12 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
|||
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
||||
self.cast = P.Cast()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.num_classes = config.num_classes
|
||||
|
||||
def construct(self, logits, label):
|
||||
logits = self.transpose(logits, (0, 2, 3, 4, 1))
|
||||
label = self.transpose(label, (0, 2, 3, 4, 1))
|
||||
label = self.cast(label, mstype.float32)
|
||||
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \
|
||||
self.reshape(label, (-1, config['num_classes']))))
|
||||
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, self.num_classes)), \
|
||||
self.reshape(label, (-1, self.num_classes))))
|
||||
return self.get_loss(loss)
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
_config_path = "./default_config.yaml"
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
else:
|
||||
raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -1,34 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from easydict import EasyDict
|
||||
config = EasyDict({
|
||||
'model': 'Unet3d',
|
||||
'lr': 0.0005,
|
||||
'epoch_size': 10,
|
||||
'batch_size': 1,
|
||||
'warmup_step': 120,
|
||||
'warmup_ratio': 0.3,
|
||||
'num_classes': 4,
|
||||
'in_channels': 1,
|
||||
'keep_checkpoint_max': 5,
|
||||
'loss_scale': 256.0,
|
||||
'roi_size': [224, 224, 96],
|
||||
'overlap': 0.25,
|
||||
'min_val': -500,
|
||||
'max_val': 1000,
|
||||
'upper_limit': 5,
|
||||
'lower_limit': 3,
|
||||
})
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from src.model_utils.config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -17,9 +17,10 @@ import mindspore.nn as nn
|
|||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.unet3d_parts import Down, Up
|
||||
from src.model_utils.config import config
|
||||
|
||||
class UNet3d(nn.Cell):
|
||||
def __init__(self, config=None):
|
||||
def __init__(self):
|
||||
super(UNet3d, self).__init__()
|
||||
self.n_channels = config.in_channels
|
||||
self.n_classes = config.num_classes
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import math
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
|
||||
def correct_nifti_head(img):
|
||||
"""
|
||||
|
@ -131,11 +131,11 @@ def one_hot(labels):
|
|||
labels = np.reshape(labels, (N, -1))
|
||||
labels = labels.astype(np.int32)
|
||||
N, K = labels.shape
|
||||
one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
|
||||
one_hot_encoding = np.zeros((N, config.num_classes, K), dtype=np.float32)
|
||||
for i in range(N):
|
||||
for j in range(K):
|
||||
one_hot_encoding[i, labels[i][j], j] = 1
|
||||
labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
|
||||
labels = np.reshape(one_hot_encoding, (N, config.num_classes, D, H, W))
|
||||
return labels
|
||||
|
||||
def CalculateDice(y_pred, label):
|
||||
|
|
|
@ -14,43 +14,36 @@
|
|||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, Model, context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from src.dataset import create_dataset
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.loss import SoftmaxCrossEntropyWithLogits
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
|
||||
device_id=device_id)
|
||||
mindspore.set_seed(1)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks')
|
||||
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
|
||||
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
|
||||
parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \
|
||||
help='Run distribute, default: false')
|
||||
return parser.parse_args()
|
||||
|
||||
def train_net(data_dir,
|
||||
seg_dir,
|
||||
run_distribute,
|
||||
config=None):
|
||||
@moxing_wrapper()
|
||||
def train_net(data_path,
|
||||
run_distribute):
|
||||
data_dir = data_path + "/image/"
|
||||
seg_dir = data_path + "/seg/"
|
||||
if run_distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_device_id()
|
||||
rank_size = get_device_num()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode,
|
||||
device_num=rank_size,
|
||||
|
@ -58,12 +51,12 @@ def train_net(data_dir,
|
|||
else:
|
||||
rank_id = 0
|
||||
rank_size = 1
|
||||
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \
|
||||
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, \
|
||||
rank_size=rank_size, rank_id=rank_id, is_training=True)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
print("train dataset length is:", train_data_size)
|
||||
|
||||
network = UNet3d(config=config)
|
||||
network = UNet3d()
|
||||
|
||||
loss = SoftmaxCrossEntropyWithLogits()
|
||||
lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
|
||||
|
@ -77,8 +70,9 @@ def train_net(data_dir,
|
|||
loss_cb = LossMonitor()
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model),
|
||||
directory='./ckpt_{}/'.format(device_id),
|
||||
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='Unet3d',
|
||||
directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id),
|
||||
config=ckpt_config)
|
||||
callbacks_list = [loss_cb, time_cb, ckpoint_cb]
|
||||
print("============== Starting Training ==============")
|
||||
|
@ -86,9 +80,5 @@ def train_net(data_dir,
|
|||
print("============== End Training ==============")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print("Training setting:", args)
|
||||
train_net(data_dir=args.data_url,
|
||||
seg_dir=args.seg_url,
|
||||
run_distribute=args.run_distribute,
|
||||
config=cfg)
|
||||
train_net(data_path=config.data_path,
|
||||
run_distribute=config.run_distribute)
|
||||
|
|
Loading…
Reference in New Issue