diff --git a/model_zoo/official/cv/unet3d/README.md b/model_zoo/official/cv/unet3d/README.md
index a971608d8ee..98f805fb228 100644
--- a/model_zoo/official/cv/unet3d/README.md
+++ b/model_zoo/official/cv/unet3d/README.md
@@ -82,6 +82,37 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
 ```
+If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), then you can start training and evaluation as follows:
+
+```python
+# run distributed training on ModelArts example
+# (1) Perform a or b.
+#       a. Set "enable_modelarts=True" in the yaml file.
+#          Set the other parameters you need in the yaml file.
+#       b. Add "enable_modelarts=True" on the website UI interface.
+#          Add the other parameters you need on the website UI interface.
+# (2) Add nibabel to pip-requirements.txt and put it in the code directory.
+# (3) Set the config directory to "config_path=/The path of config in S3/".
+# (4) Set the code directory to "/path/unet3d" on the website UI interface.
+# (5) Set the startup file to "train.py" on the website UI interface.
+# (6) Set "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
+# (7) Create your job.
+
+# run evaluation on ModelArts example
+# (1) Copy or upload your trained model to the S3 bucket.
+# (2) Perform a or b.
+#       a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the yaml file.
+#          Set "checkpoint_url=/The path of checkpoint in S3/" in the yaml file.
+#       b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
+#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
+# (3) Add nibabel to pip-requirements.txt and put it in the code directory.
+# (4) Set the config directory to "config_path=/The path of config in S3/".
+# (5) Set the code directory to "/path/unet3d" on the website UI interface.
+# (6) Set the startup file to "eval.py" on the website UI interface.
+# (7) Set "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
+# (8) Create your job.
+``` + ## [Script Description](#contents) ### [Script and Sample Code](#contents) @@ -96,7 +127,6 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path │ ├──run_standalone_train.sh // shell script for standalone on Ascend │ ├──run_standalone_eval.sh // shell script for evaluation on Ascend ├── src - │ ├──config.py // parameter configuration │ ├──dataset.py // creating dataset │   ├──lr_schedule.py // learning rate scheduler │ ├──transform.py // handle dataset @@ -105,6 +135,12 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path │ ├──utils.py // General components (callback function) │ ├──unet3d_model.py // Unet3D model │ ├──unet3d_parts.py // Unet3D part + ├── model_utils + │ ├──config.py // parameter configuration + │ ├──device_adapter.py // device adapter + │ ├──local_adapter.py // local adapter + │ ├──moxing_adapter.py // moxing adapter + ├── default_config.yaml // parameter configuration ├── train.py // training script ├── eval.py // evaluation script diff --git a/model_zoo/official/cv/unet3d/default_config.yaml b/model_zoo/official/cv/unet3d/default_config.yaml new file mode 100644 index 00000000000..bcad19a6eca --- /dev/null +++ b/model_zoo/official/cv/unet3d/default_config.yaml @@ -0,0 +1,57 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +run_distribute: False +enable_profiling: False +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'Unet3d-9-877.ckpt' + +# ============================================================================== +# Training options +lr: 0.0005 +batch_size: 1 +epoch_size: 10 +warmup_step: 120 +warmup_ratio: 0.3 +num_classes: 4 +in_channels: 1 +keep_checkpoint_max: 1 +loss_scale: 256.0 +roi_size : [224, 224, 96] +overlap: 0.25 +min_val: -500 +max_val: 1000 +upper_limit: 5 +lower_limit: 3 + +# Export options +device_id: 0 +ckpt_file: "" +file_name: "" +file_format: "" + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +epoch_size: "Total training epochs." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +checkpoint_path: "The location of the checkpoint file." +checkpoint_file_path: "The location of the checkpoint file." 
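For reference, the two documents in `default_config.yaml` above (the values, then the help text after `---`) are resolved into the dot-style `config` object that the rest of this patch imports from `src.model_utils.config`. The sketch below is only an illustration of that mechanism, not code from the patch; the real `config.py` additionally merges `--flag` command-line overrides onto the yaml values.

```python
# Minimal sketch (illustration only): load a two-document YAML like
# default_config.yaml and expose it as attribute-style members.
import yaml

YAML_TEXT = """
lr: 0.0005
batch_size: 1
roi_size: [224, 224, 96]
---
lr: 'Learning rate'
batch_size: 'Batch size for training and evaluation'
"""

class Config:
    """Convert a dictionary into attribute-style members."""
    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            setattr(self, key, Config(value) if isinstance(value, dict) else value)

docs = list(yaml.load_all(YAML_TEXT, Loader=yaml.FullLoader))
values, helper = docs[0], (docs[1] if len(docs) > 1 else {})
config = Config(values)
print(config.lr, config.roi_size)  # 0.0005 [224, 224, 96]
```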
diff --git a/model_zoo/official/cv/unet3d/eval.py b/model_zoo/official/cv/unet3d/eval.py index e0c8f8e0a99..e49ad5c77fb 100644 --- a/model_zoo/official/cv/unet3d/eval.py +++ b/model_zoo/official/cv/unet3d/eval.py @@ -14,32 +14,28 @@ # ============================================================================ import os -import argparse import numpy as np from mindspore import dtype as mstype from mindspore import Model, context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.dataset import create_dataset from src.unet3d_model import UNet3d -from src.config import config as cfg from src.utils import create_sliding_window, CalculateDice +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) -def get_args(): - parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks') - parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory') - parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory') - parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path') - return parser.parse_args() - -def test_net(data_dir, seg_dir, ckpt_path, config=None): - eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False) +@moxing_wrapper() +def test_net(data_path, ckpt_path): + data_dir = data_path + "/image/" + seg_dir = data_path + "/seg/" + eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, is_training=False) eval_data_size = eval_dataset.get_dataset_size() print("train dataset length is:", eval_data_size) - network = UNet3d(config=config) + network = UNet3d() network.set_train(False) param_dict = load_checkpoint(ckpt_path) load_param_into_net(network, param_dict) @@ -70,9 +66,5 @@ def test_net(data_dir, seg_dir, ckpt_path, config=None): print("eval average dice is {}".format(avg_dice)) if __name__ == '__main__': - args = get_args() - print("Testing setting:", args) - test_net(data_dir=args.data_url, - seg_dir=args.seg_url, - ckpt_path=args.ckpt_path, - config=cfg) + test_net(data_path=config.data_path, + ckpt_path=config.checkpoint_file_path) diff --git a/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh b/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh index 3a1d7f3ecc4..d3a886a2344 100644 --- a/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================ -if [ $# -ne 3 ] +if [ $# -ne 2 ] then echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]" exit 1 @@ -45,14 +45,6 @@ then exit 1 fi -PATH3=$(get_real_path $3) -echo $PATH3 -if [ ! 
-d $PATH3 ] -then - echo "error: SEG_PATH=$PATH3 is not a file" -exit 1 -fi - ulimit -u unlimited export DEVICE_NUM=8 export RANK_SIZE=8 @@ -65,6 +57,7 @@ do rm -rf ./train_parallel$i mkdir ./train_parallel$i cp ../*.py ./train_parallel$i + cp ../*.yaml ./train_parallel$i cp *.sh ./train_parallel$i cp -r ../src ./train_parallel$i cd ./train_parallel$i || exit @@ -73,8 +66,8 @@ do python train.py \ --run_distribute=True \ - --data_url=$PATH2 \ - --seg_url=$PATH3 > log.txt 2>&1 & + --data_path=$PATH2 \ + --output_path './output' > log.txt 2>&1 & cd ../ done diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh index aaf83fe1df1..f377be58ddd 100644 --- a/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================ -if [ $# != 3 ] +if [ $# != 2 ] then echo "==============================================================================================================" echo "Please run the script as: " @@ -23,7 +23,7 @@ then echo "==============================================================================================================" fi -if [ $# != 3 ] +if [ $# != 2 ] then echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]" exit 1 @@ -36,22 +36,14 @@ get_real_path(){ echo "$(realpath -m $PWD/$1)" fi } -IMAGE_PATH=$(get_real_path $1) -SEG_PATH=$(get_real_path $2) -CHECKPOINT_FILE_PATH=$(get_real_path $3) -echo $IMAGE_PATH -echo $SEG_PATH +PATH1=$(get_real_path $1) +CHECKPOINT_FILE_PATH=$(get_real_path $2) +echo $PATH1 echo $CHECKPOINT_FILE_PATH -if [ ! -d $IMAGE_PATH ] +if [ ! -d $PATH1 ] then - echo "error: IMAGE_PATH=$IMAGE_PATH is not a path" -exit 1 -fi - -if [ ! -d $SEG_PATH ] -then - echo "error: SEG_PATH=$SEG_PATH is not a path" + echo "error: PATH1=$PATH1 is not a path" exit 1 fi @@ -74,9 +66,10 @@ fi mkdir ./eval cp ../*.py ./eval cp *.sh ./eval +cp ../*.yaml ./eval cp -r ../src ./eval cd ./eval || exit echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" -python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 & +python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 & echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" cd .. diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh index f3967a5da2d..1e30cea5c7c 100644 --- a/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================ -if [ $# -ne 2 ] +if [ $# -ne 1 ] then echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]" exit 1 @@ -36,14 +36,6 @@ then exit 1 fi -PATH2=$(get_real_path $2) -echo $PATH2 -if [ ! 
-d $PATH2 ] -then - echo "error: SEG_PATH=$PATH2 is not a file" -exit 1 -fi - ulimit -u unlimited export DEVICE_NUM=1 export DEVICE_ID=0 @@ -54,9 +46,10 @@ rm -rf ./train mkdir ./train cp ../*.py ./train cp *.sh ./train +cp ../*.yaml ./train cp -r ../src ./train cd ./train || exit echo "start training for device $DEVICE_ID" env > env.log -python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 & +python train.py --data_path=$PATH1 --output_path './output' > train.log 2>&1 & cd .. diff --git a/model_zoo/official/cv/unet3d/src/convert_nifti.py b/model_zoo/official/cv/unet3d/src/convert_nifti.py index d00544ddb45..0e6622007df 100644 --- a/model_zoo/official/cv/unet3d/src/convert_nifti.py +++ b/model_zoo/official/cv/unet3d/src/convert_nifti.py @@ -17,7 +17,7 @@ import os import argparse from pathlib import Path import SimpleITK as sitk -from src.config import config +from src.model_utils.config import config parser = argparse.ArgumentParser() parser.add_argument("--input_path", type=str, help="Input image directory to be processed.") diff --git a/model_zoo/official/cv/unet3d/src/dataset.py b/model_zoo/official/cv/unet3d/src/dataset.py index 98e162cef06..b3b828e3c04 100644 --- a/model_zoo/official/cv/unet3d/src/dataset.py +++ b/model_zoo/official/cv/unet3d/src/dataset.py @@ -18,7 +18,7 @@ import glob import numpy as np import mindspore.dataset as ds from mindspore.dataset.transforms.py_transforms import Compose -from src.config import config as cfg +from src.model_utils.config import config from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot class ConvertLabel: @@ -34,16 +34,16 @@ class ConvertLabel: Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
""" - data[data > cfg['upper_limit']] = 0 - data = data - (cfg['lower_limit'] - 1) - data = np.clip(data, 0, cfg['lower_limit']) + data[data > config.upper_limit] = 0 + data = data - (config.lower_limit - 1) + data = np.clip(data, 0, config.lower_limit) return data def __call__(self, image, label): label = self.operation(label) return image, label -def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True): +def create_dataset(data_path, seg_path, rank_size=1, rank_id=0, is_training=True): seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz"))) train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files] train_ds = Dataset(data=train_files, seg=seg_files) diff --git a/model_zoo/official/cv/unet3d/src/loss.py b/model_zoo/official/cv/unet3d/src/loss.py index 1106b0cd2f9..45d65deae4e 100644 --- a/model_zoo/official/cv/unet3d/src/loss.py +++ b/model_zoo/official/cv/unet3d/src/loss.py @@ -17,7 +17,7 @@ import mindspore.nn as nn from mindspore import dtype as mstype from mindspore.ops import operations as P from mindspore.nn.loss.loss import _Loss -from src.config import config +from src.model_utils.config import config class SoftmaxCrossEntropyWithLogits(_Loss): def __init__(self): @@ -27,11 +27,12 @@ class SoftmaxCrossEntropyWithLogits(_Loss): self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) self.cast = P.Cast() self.reduce_mean = P.ReduceMean() + self.num_classes = config.num_classes def construct(self, logits, label): logits = self.transpose(logits, (0, 2, 3, 4, 1)) label = self.transpose(label, (0, 2, 3, 4, 1)) label = self.cast(label, mstype.float32) - loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \ - self.reshape(label, (-1, config['num_classes'])))) + loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, self.num_classes)), \ + self.reshape(label, (-1, self.num_classes)))) return self.get_loss(loss) diff --git a/model_zoo/official/cv/unet3d/src/model_utils/config.py b/model_zoo/official/cv/unet3d/src/model_utils/config.py new file mode 100644 index 00000000000..92136db1e0c --- /dev/null +++ b/model_zoo/official/cv/unet3d/src/model_utils/config.py @@ -0,0 +1,125 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml + +_config_path = "./default_config.yaml" + +class Config: + """ + Configuration namespace. Convert dictionary to members. 
+ """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. + """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + else: + raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml") + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser, default, helper, path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() diff --git a/model_zoo/official/cv/unet3d/src/config.py b/model_zoo/official/cv/unet3d/src/model_utils/device_adapter.py similarity index 53% rename from model_zoo/official/cv/unet3d/src/config.py rename to model_zoo/official/cv/unet3d/src/model_utils/device_adapter.py index 70ab7290781..9c3d21d5e47 100644 --- a/model_zoo/official/cv/unet3d/src/config.py +++ b/model_zoo/official/cv/unet3d/src/model_utils/device_adapter.py @@ -1,34 +1,27 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# less required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -from easydict import EasyDict -config = EasyDict({ - 'model': 'Unet3d', - 'lr': 0.0005, - 'epoch_size': 10, - 'batch_size': 1, - 'warmup_step': 120, - 'warmup_ratio': 0.3, - 'num_classes': 4, - 'in_channels': 1, - 'keep_checkpoint_max': 5, - 'loss_scale': 256.0, - 'roi_size': [224, 224, 96], - 'overlap': 0.25, - 'min_val': -500, - 'max_val': 1000, - 'upper_limit': 5, - 'lower_limit': 3, -}) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from src.model_utils.config import config + +if config.enable_modelarts: + from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/official/cv/unet3d/src/model_utils/local_adapter.py b/model_zoo/official/cv/unet3d/src/model_utils/local_adapter.py new file mode 100644 index 00000000000..769fa6dc78e --- /dev/null +++ b/model_zoo/official/cv/unet3d/src/model_utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/model_zoo/official/cv/unet3d/src/model_utils/moxing_adapter.py b/model_zoo/official/cv/unet3d/src/model_utils/moxing_adapter.py new file mode 100644 index 00000000000..aabd5ac6cf1 --- /dev/null +++ b/model_zoo/official/cv/unet3d/src/model_utils/moxing_adapter.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Moxing adapter for ModelArts""" + +import os +import functools +from mindspore import context +from src.model_utils.config import config + +_global_sync_count = 0 + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv('JOB_ID') + job_id = job_id if job_id != "" else "default" + return job_id + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local path + Upload data from local directory to remote obs in contrast. + """ + import moxing as mox + import time + global _global_sync_count + sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) + _global_sync_count += 1 + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("from path: ", from_path) + print("to path: ", to_path) + mox.file.copy_parallel(from_path, to_path) + print("===finish data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + print("===save flag===") + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Finish sync data from {} to {}.".format(from_path, to_path)) + + +def moxing_wrapper(pre_process=None, post_process=None): + """ + Moxing wrapper to download dataset and upload outputs. 
+ """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + run_func(*args, **kwargs) + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper diff --git a/model_zoo/official/cv/unet3d/src/unet3d_model.py b/model_zoo/official/cv/unet3d/src/unet3d_model.py index ba21f50331c..85bae388e7a 100644 --- a/model_zoo/official/cv/unet3d/src/unet3d_model.py +++ b/model_zoo/official/cv/unet3d/src/unet3d_model.py @@ -17,9 +17,10 @@ import mindspore.nn as nn from mindspore import dtype as mstype from mindspore.ops import operations as P from src.unet3d_parts import Down, Up +from src.model_utils.config import config class UNet3d(nn.Cell): - def __init__(self, config=None): + def __init__(self): super(UNet3d, self).__init__() self.n_channels = config.in_channels self.n_classes = config.num_classes diff --git a/model_zoo/official/cv/unet3d/src/utils.py b/model_zoo/official/cv/unet3d/src/utils.py index efa78971dac..dbbffa6c1bb 100644 --- a/model_zoo/official/cv/unet3d/src/utils.py +++ b/model_zoo/official/cv/unet3d/src/utils.py @@ -15,7 +15,7 @@ import math import numpy as np -from src.config import config +from src.model_utils.config import config def correct_nifti_head(img): """ @@ -131,11 +131,11 @@ def one_hot(labels): labels = np.reshape(labels, (N, -1)) labels = labels.astype(np.int32) N, K = labels.shape - one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32) + one_hot_encoding = np.zeros((N, config.num_classes, K), dtype=np.float32) for i in range(N): for j in range(K): one_hot_encoding[i, labels[i][j], j] = 1 - labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W)) + labels = np.reshape(one_hot_encoding, (N, config.num_classes, D, H, W)) return labels def CalculateDice(y_pred, label): diff --git a/model_zoo/official/cv/unet3d/train.py b/model_zoo/official/cv/unet3d/train.py index 6729e884a11..eb85dabcf88 100644 --- a/model_zoo/official/cv/unet3d/train.py +++ b/model_zoo/official/cv/unet3d/train.py @@ -14,43 +14,36 @@ # ============================================================================ import os -import argparse -import ast import mindspore import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore import Tensor, Model, context from mindspore.context import ParallelMode -from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.communication.management import init from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, 
LossMonitor, TimeMonitor from src.dataset import create_dataset from src.unet3d_model import UNet3d -from src.config import config as cfg from src.lr_schedule import dynamic_lr from src.loss import SoftmaxCrossEntropyWithLogits +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id, get_device_num device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \ device_id=device_id) mindspore.set_seed(1) -def get_args(): - parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks') - parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory') - parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory') - parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \ - help='Run distribute, default: false') - return parser.parse_args() - -def train_net(data_dir, - seg_dir, - run_distribute, - config=None): +@moxing_wrapper() +def train_net(data_path, + run_distribute): + data_dir = data_path + "/image/" + seg_dir = data_path + "/seg/" if run_distribute: init() - rank_id = get_rank() - rank_size = get_group_size() + rank_id = get_device_id() + rank_size = get_device_num() parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, @@ -58,12 +51,12 @@ def train_net(data_dir, else: rank_id = 0 rank_size = 1 - train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \ + train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, \ rank_size=rank_size, rank_id=rank_id, is_training=True) train_data_size = train_dataset.get_dataset_size() print("train dataset length is:", train_data_size) - network = UNet3d(config=config) + network = UNet3d() loss = SoftmaxCrossEntropyWithLogits() lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32) @@ -77,8 +70,9 @@ def train_net(data_dir, loss_cb = LossMonitor() ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, keep_checkpoint_max=config.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model), - directory='./ckpt_{}/'.format(device_id), + ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path) + ckpoint_cb = ModelCheckpoint(prefix='Unet3d', + directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id), config=ckpt_config) callbacks_list = [loss_cb, time_cb, ckpoint_cb] print("============== Starting Training ==============") @@ -86,9 +80,5 @@ def train_net(data_dir, print("============== End Training ==============") if __name__ == '__main__': - args = get_args() - print("Training setting:", args) - train_net(data_dir=args.data_url, - seg_dir=args.seg_url, - run_distribute=args.run_distribute, - config=cfg) + train_net(data_path=config.data_path, + run_distribute=config.run_distribute)
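With these changes, `train_net` and `test_net` read their inputs from the yaml-driven `config` and are decorated with `@moxing_wrapper()`, so the same entry point runs locally and on ModelArts. The following is a simplified, self-contained sketch of that decorator pattern, not the patch code: the `enable_modelarts` argument and the `print` calls stand in for `config.enable_modelarts` and the OBS `sync_data` transfers implemented in `src/model_utils/moxing_adapter.py`.

```python
# Simplified illustration of the moxing_wrapper pattern: sync inputs before
# the entry function runs, sync outputs afterwards, do nothing when local.
import functools

def moxing_wrapper(pre_process=None, post_process=None, enable_modelarts=False):
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            if enable_modelarts:
                print("download dataset/checkpoint from OBS")  # sync_data(...) in the real adapter
            if pre_process:
                pre_process()
            run_func(*args, **kwargs)
            if enable_modelarts:
                if post_process:
                    post_process()
                print("upload outputs back to OBS")  # sync_data(...) in the real adapter
        return wrapped_func
    return wrapper

@moxing_wrapper(enable_modelarts=False)  # train.py/eval.py use @moxing_wrapper() and read the flag from config
def train_net(data_path, run_distribute):
    print("training with", data_path, run_distribute)

train_net(data_path="/cache/data", run_distribute=False)
```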