diff --git a/model_zoo/official/cv/unet/README.md b/model_zoo/official/cv/unet/README.md
index 58ccbdfbf5d..fdd0419c899 100644
--- a/model_zoo/official/cv/unet/README.md
+++ b/model_zoo/official/cv/unet/README.md
@@ -160,6 +160,35 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
 
 Then you can run everything just like on ascend.
 
+If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/); you can start training and evaluation as follows:
+
+```python
+# run distributed training on ModelArts example
+# (1) Perform a or b.
+#       a. Set "enable_modelarts=True" in the yaml file.
+#          Set the other parameters you need in the yaml file.
+#       b. Add "enable_modelarts=True" on the website UI.
+#          Add the other parameters you need on the website UI.
+# (2) Set the config path to "config_path=/The path of config in S3/"
+# (3) Set the code directory to "/path/unet" on the website UI.
+# (4) Set the startup file to "train.py" on the website UI.
+# (5) Set the "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI.
+# (6) Create your job.
+
+# run evaluation on ModelArts example
+# (1) Copy or upload your trained model to an S3 bucket.
+# (2) Perform a or b.
+#       a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the yaml file.
+#          Set "checkpoint_url=/The path of checkpoint in S3/" in the yaml file.
+#       b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI.
+#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI.
+# (3) Set the config path to "config_path=/The path of config in S3/"
+# (4) Set the code directory to "/path/unet" on the website UI.
+# (5) Set the startup file to "eval.py" on the website UI.
+# (6) Set the "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI.
+# (7) Create your job.
+```
+
 ## [Script Description](#contents)
 
 ### [Script and Sample Code](#contents)
 
@@ -190,6 +219,16 @@ Then you can run everything just like on ascend.
   ├──__init__.py // init file
   ├──unet_model.py // unet model
   ├──unet_parts.py // unet part
+  ├── model_utils
+  │   ├── config.py // parameter configuration
+  │   ├── device_adapter.py // device adapter
+  │   ├── local_adapter.py // local adapter
+  │   ├── moxing_adapter.py // moxing adapter
+  ├── unet_medical_config.yaml // parameter configuration
+  ├── unet_nested_cell_config.yaml // parameter configuration
+  ├── unet_nested_config.yaml // parameter configuration
+  ├── unet_simple_config.yaml // parameter configuration
+  ├── unet_simple_coco_config.yaml // parameter configuration
   ├── train.py // training script
   ├── eval.py // evaluation script
   ├── export.py // export script
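Every entry point in this change now takes an explicit yaml config. For quick reference, the commands below sketch the typical local invocations under the new three-argument script interface described later in this patch; all paths are placeholders:

```bash
# standalone training: [DATASET] [CONFIG_PATH] [DEVICE_ID](optional, default 0)
bash scripts/run_standalone_train.sh /path/to/data/ /path/to/unet_simple_config.yaml 0

# standalone evaluation: [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](optional, default 0)
bash scripts/run_standalone_eval.sh /path/to/data/ /path/to/checkpoint.ckpt /path/to/unet_simple_config.yaml 0

# distributed training: [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]
bash scripts/run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data /absolute/path/to/config.yaml
```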
增加 "enable_modelarts=True" 参数在modearts的界面上。 +# 在modelarts的界面上设置网络所需的参数。 +# (2)设置网络配置文件的路径 "config_path=/The path of config in S3/" +# (3) 在modelarts的界面上设置代码的路径 "/path/unet"。 +# (4) 在modelarts的界面上设置模型的启动文件 "train.py" 。 +# (5) 在modelarts的界面上设置模型的数据路径 "Dataset path" , +# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。 +# (6) 开始模型的训练。 + +# 在modelarts上使用模型推理的示例 +# (1) 把训练好的模型地方到桶的对应位置。 +# (2) 选址a或者b其中一种方式。 +# a. 设置 "checkpoint_file_path='/cache/checkpoint_path/model.ckpt" 在 yaml 文件. +# 设置 "checkpoint_url=/The path of checkpoint in S3/" 在 yaml 文件. +# b. 增加 "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" 参数在modearts的界面上。 +# 增加 "checkpoint_url=/The path of checkpoint in S3/" 参数在modearts的界面上。 +# (3) 设置网络配置文件的路径 "config_path=/The path of config in S3/" +# (4) 在modelarts的界面上设置代码的路径 "/path/unet"。 +# (5) 在modelarts的界面上设置模型的启动文件 "eval.py" 。 +# (6) 在modelarts的界面上设置模型的数据路径 "Dataset path" , +# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。 +# (7) 开始模型的推理。 +``` + ## 脚本说明 ### 脚本及样例代码 @@ -194,6 +226,16 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR] ├──__init__.py ├──unet_model.py // Unet++ 网络结构 ├──unet_parts.py // Unet++ 子网 + ├── model_utils + │ ├── config.py // 参数配置 + │ ├── device_adapter.py // 设备配置 + │ ├── local_adapter.py // 本地设备配置 + │ ├── moxing_adapter.py // modelarts设备配置 + ├── unet_medical_config.yaml // 配置文件 + ├── unet_nested_cell_config.yaml // 配置文件 + ├── unet_nested_config.yaml // 配置文件 + ├── unet_simple_config.yaml // 配置文件 + ├── unet_simple_coco_config.yaml // 配置文件 ├── train.py // 训练脚本 ├── eval.py // 推理脚本 ├── export.py // 导出脚本 diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py index 4cdc8ba6977..42ba491bd0f 100644 --- a/model_zoo/official/cv/unet/eval.py +++ b/model_zoo/official/cv/unet/eval.py @@ -14,7 +14,6 @@ # ============================================================================ import os -import argparse import logging from mindspore import context, Model from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -22,38 +21,39 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.data_loader import create_dataset, create_multi_class_dataset from src.unet_medical import UNetMedical from src.unet_nested import NestedUNet, UNet -from src.config import cfg_unet from src.utils import UnetEval, TempLoss, dice_coeff +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper -device_id = int(os.getenv('DEVICE_ID')) +device_id = int(os.getenv("DEVICE_ID")) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) +@moxing_wrapper() def test_net(data_dir, ckpt_path, - cross_valid_ind=1, - cfg=None): - if cfg['model'] == 'unet_medical': - net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) - elif cfg['model'] == 'unet_nested': - net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], - use_bn=cfg['use_bn'], use_ds=False) - elif cfg['model'] == 'unet_simple': - net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + cross_valid_ind=1): + if config.model_name == 'unet_medical': + net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes) + elif config.model_name == 'unet_nested': + net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv, + use_bn=config.use_bn, use_ds=False) + elif config.model_name == 
diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py
index 4cdc8ba6977..42ba491bd0f 100644
--- a/model_zoo/official/cv/unet/eval.py
+++ b/model_zoo/official/cv/unet/eval.py
@@ -14,7 +14,6 @@
 # ============================================================================
 
 import os
-import argparse
 import logging
 from mindspore import context, Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -22,38 +21,39 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.data_loader import create_dataset, create_multi_class_dataset
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
-from src.config import cfg_unet
 from src.utils import UnetEval, TempLoss, dice_coeff
+from src.model_utils.config import config
+from src.model_utils.moxing_adapter import moxing_wrapper
 
-device_id = int(os.getenv('DEVICE_ID'))
+device_id = int(os.getenv("DEVICE_ID"))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
 
+@moxing_wrapper()
 def test_net(data_dir,
              ckpt_path,
-             cross_valid_ind=1,
-             cfg=None):
-    if cfg['model'] == 'unet_medical':
-        net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
-    elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
-                         use_bn=cfg['use_bn'], use_ds=False)
-    elif cfg['model'] == 'unet_simple':
-        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+             cross_valid_ind=1):
+    if config.model_name == 'unet_medical':
+        net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
+    elif config.model_name == 'unet_nested':
+        net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
+                         use_bn=config.use_bn, use_ds=False)
+    elif config.model_name == 'unet_simple':
+        net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
     else:
-        raise ValueError("Unsupported model: {}".format(cfg['model']))
+        raise ValueError("Unsupported model: {}".format(config.model_name))
     param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(net, param_dict)
     net = UnetEval(net)
-    if 'dataset' in cfg and cfg['dataset'] != "ISBI":
-        split = cfg['split'] if 'split' in cfg else 0.8
-        valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
-                                                   num_classes=cfg['num_classes'], is_train=False,
-                                                   eval_resize=cfg["eval_resize"], split=split,
+    if hasattr(config, "dataset") and config.dataset != "ISBI":
+        split = config.split if hasattr(config, "split") else 0.8
+        valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
+                                                   num_classes=config.num_classes, is_train=False,
+                                                   eval_resize=config.eval_resize, split=split,
                                                    python_multiprocessing=False, shuffle=False)
     else:
         _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
-                                          do_crop=cfg['crop'], img_size=cfg['img_size'])
-    model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)})
+                                          do_crop=config.crop, img_size=config.image_size)
+    model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
 
     print("============== Starting Evaluating ============")
     eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
@@ -61,22 +61,8 @@ def test_net(data_dir,
     print("============== Cross valid IOU is:", eval_score[1])
 
 
-def get_args():
-    parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
-                        help='data directory')
-    parser.add_argument('-p', '--ckpt_path', dest='ckpt_path', type=str, default='ckpt_unet_medical_adam-1_600.ckpt',
-                        help='checkpoint path')
-
-    return parser.parse_args()
-
-
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
-    args = get_args()
-    print("Testing setting:", args)
-    test_net(data_dir=args.data_url,
-             ckpt_path=args.ckpt_path,
-             cross_valid_ind=cfg_unet['cross_valid_ind'],
-             cfg=cfg_unet)
+    test_net(data_dir=config.data_path,
+             ckpt_path=config.checkpoint_file_path,
+             cross_valid_ind=config.cross_valid_ind)
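test_net above (and train_net later in this patch) is wrapped with `@moxing_wrapper()`. The wrapper's implementation is only renamed by this diff (from `textcnn/utils/moxing_adapter.py`), not shown, so the sketch below is a hypothetical illustration of the usual shape of such an adapter, nothing more:

```python
# Hypothetical sketch of a moxing-style wrapper (the real one lives in
# src/model_utils/moxing_adapter.py): sync OBS inputs before the job, outputs after.
import functools

def moxing_wrapper(pre_process=None, post_process=None):
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped(*args, **kwargs):
            # On ModelArts: download config.data_url to config.data_path here.
            if pre_process:
                pre_process()
            ret = run_func(*args, **kwargs)
            # On ModelArts: upload config.output_path back to config.train_url here.
            if post_process:
                post_process()
            return ret
        return wrapped
    return wrapper
```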
diff --git a/model_zoo/official/cv/unet/export.py b/model_zoo/official/cv/unet/export.py
index b4af88f316c..9ced1d3c04e 100644
--- a/model_zoo/official/cv/unet/export.py
+++ b/model_zoo/official/cv/unet/export.py
@@ -13,46 +13,36 @@
 # limitations under the License.
 # ============================================================================
 
-import argparse
 import numpy as np
 
 from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context
 
 from src.unet_medical.unet_model import UNetMedical
 from src.unet_nested import NestedUNet, UNet
-from src.config import cfg_unet as cfg
 from src.utils import UnetEval
+from src.model_utils.config import config
+from src.model_utils.device_adapter import get_device_id
 
-parser = argparse.ArgumentParser(description='unet export')
-parser.add_argument("--device_id", type=int, default=0, help="Device id")
-parser.add_argument("--batch_size", type=int, default=1, help="batch size")
-parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
-parser.add_argument('--width', type=int, default=572, help='input width')
-parser.add_argument('--height', type=int, default=572, help='input height')
-parser.add_argument("--file_name", type=str, default="unet", help="output file name.")
-parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
-parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
-                    help="device target")
-args = parser.parse_args()
 
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
-if args.device_target == "Ascend":
-    context.set_context(device_id=args.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+if config.device_target == "Ascend":
+    context.set_context(device_id=get_device_id())
 
 if __name__ == "__main__":
-    if cfg['model'] == 'unet_medical':
-        net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
-    elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
-                         use_bn=cfg['use_bn'], use_ds=False)
-    elif cfg['model'] == 'unet_simple':
-        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    if config.model_name == 'unet_medical':
+        net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
+    elif config.model_name == 'unet_nested':
+        net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
+                         use_bn=config.use_bn, use_ds=False)
+    elif config.model_name == 'unet_simple':
+        net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
     else:
-        raise ValueError("Unsupported model: {}".format(cfg['model']))
+        raise ValueError("Unsupported model: {}".format(config.model_name))
     # return a parameter dict for model
-    param_dict = load_checkpoint(args.ckpt_file)
+    param_dict = load_checkpoint(config.checkpoint_file_path)
     # load the parameter into net
     load_param_into_net(net, param_dict)
     net = UnetEval(net)
-    input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32))
-    export(net, input_data, file_name=args.file_name, file_format=args.file_format)
+    input_data = Tensor(np.ones([config.batch_size, config.num_channels, config.height, \
+                                 config.width]).astype(np.float32))
+    export(net, input_data, file_name=config.file_name, file_format=config.file_format)
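With argparse removed from export.py, the export parameters (`checkpoint_file_path`, `batch_size`, `height`, `width`, `file_name`, `file_format`) now come from the "Export options" section of the chosen yaml, and any of them can be overridden per run. A plausible invocation (checkpoint name is a placeholder):

```bash
python export.py --config_path=./unet_simple_config.yaml \
    --checkpoint_file_path=./ckpt_unet_simple_adam-4-75.ckpt \
    --file_name=unet --file_format=MINDIR --batch_size=1
```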
diff --git a/model_zoo/official/cv/unet/postprocess.py b/model_zoo/official/cv/unet/postprocess.py
index f94b970cdce..45b481630b2 100644
--- a/model_zoo/official/cv/unet/postprocess.py
+++ b/model_zoo/official/cv/unet/postprocess.py
@@ -14,11 +14,10 @@
 # ============================================================================
 """unet 310 infer."""
 import os
-import argparse
 import cv2
 import numpy as np
 
-from src.config import cfg_unet
+from src.model_utils.config import config
 
 class dice_coeff():
     def __init__(self):
@@ -38,20 +37,20 @@ class dice_coeff():
         if b != 1:
             raise ValueError('Batch size should be 1 when in evaluation.')
         y = y.reshape((h, w, c))
-        if cfg_unet["eval_activate"].lower() == "softmax":
+        if config.eval_activate.lower() == "softmax":
             y_softmax = np.squeeze(inputs[0][0], axis=0)
-            if cfg_unet["eval_resize"]:
+            if config.eval_resize:
                 y_pred = []
-                for m in range(cfg_unet["num_classes"]):
+                for m in range(config.num_classes):
                     y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
                 y_pred = np.stack(y_pred, axis=-1)
             else:
                 y_pred = y_softmax
-        elif cfg_unet["eval_activate"].lower() == "argmax":
+        elif config.eval_activate.lower() == "argmax":
             y_argmax = np.squeeze(inputs[0][1], axis=0)
             y_pred = []
-            for n in range(cfg_unet["num_classes"]):
-                if cfg_unet["eval_resize"]:
+            for n in range(config.num_classes):
+                if config.eval_resize:
                     y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
                 else:
                     y_pred.append(np.float32(y_argmax == n))
@@ -73,25 +72,13 @@ class dice_coeff():
             raise RuntimeError('Total samples num must not be 0.')
         return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
 
-def get_args():
-    parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
-                        help='data directory')
-    parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
-                        help='infer result path')
-
-    return parser.parse_args()
-
 
 if __name__ == '__main__':
-    args = get_args()
-
-    rst_path = args.rst_path
+    rst_path = config.rst_path
     metrics = dice_coeff()
 
-    if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
-        img_size = tuple(cfg_unet['img_size'])
+    if hasattr(config, "dataset") and config.dataset == "Cell_nuclei":
+        img_size = tuple(config.image_size)
         for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
             f = bin_name.replace(".png", "")
             bin_name_softmax = f + "_0.bin"
             bin_name_argmax = f + "_1.bin"
             file_name_sof = rst_path + bin_name_softmax
             file_name_arg = rst_path + bin_name_argmax
             softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
             argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
-            mask = cv2.imread(os.path.join(args.data_url, f, "mask.png"), cv2.IMREAD_GRAYSCALE)
+            mask = cv2.imread(os.path.join(config.data_path, f, "mask.png"), cv2.IMREAD_GRAYSCALE)
             mask = cv2.resize(mask, img_size)
             mask = mask.astype(np.float32) / 255
             mask = (mask > 0.5).astype(np.int)
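postprocess.py keeps its own standalone `dice_coeff` accumulator for 310 inference results. For reference, the per-sample quantities it accumulates are the standard Dice coefficient and IoU; a self-contained numpy illustration of those formulas (not the repo's exact code):

```python
import numpy as np

def dice_and_iou(y_pred, y):
    """Dice = 2|A∩B| / (|A|+|B|); IoU = |A∩B| / (|A|+|B|-|A∩B|), on flattened masks."""
    inter = float(np.dot(y_pred.flatten(), y.flatten()))
    union = float(np.sum(y_pred) + np.sum(y))
    dice = 2.0 * inter / (union + 1e-6)
    iou = inter / (union - inter + 1e-6)
    return dice, iou
```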
diff --git a/model_zoo/official/cv/unet/preprocess.py b/model_zoo/official/cv/unet/preprocess.py
index 5b6015eb906..6915939adee 100644
--- a/model_zoo/official/cv/unet/preprocess.py
+++ b/model_zoo/official/cv/unet/preprocess.py
@@ -13,19 +13,18 @@
 # limitations under the License.
 # ============================================================================
 """unet 310 infer preprocess dataset"""
-import argparse
 import os
 import numpy as np
 import cv2
 from src.data_loader import create_dataset
-from src.config import cfg_unet
+from src.model_utils.config import config
 
 
-def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):
+def preprocess_dataset(data_dir, result_path, cross_valid_ind=1):
 
-    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
-                                      img_size=cfg['img_size'])
+    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=config.crop,
+                                      img_size=config.image_size)
 
     labels_list = []
 
     for i, data in enumerate(valid_dataset):
@@ -87,21 +86,9 @@ class CellNucleiDataset:
         return len(self.val_ids)
 
 
-def get_args():
-    parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
-                        help='data directory')
-    parser.add_argument('-p', '--result_path', dest='result_path', type=str, default='./preprocess_Result/',
-                        help='result path')
-    return parser.parse_args()
-
-
 if __name__ == '__main__':
-    args = get_args()
-
-    if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
-        cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8)
+    if hasattr(config, "dataset") and config.dataset == "Cell_nuclei":
+        cell_dataset = CellNucleiDataset(config.data_path, 1, config.result_path, False, 0.8)
     else:
-        preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet,
-                           result_path=args.result_path)
+        preprocess_dataset(data_dir=config.data_path, cross_valid_ind=config.cross_valid_ind,
+                           result_path=config.result_path)
diff --git a/model_zoo/official/cv/unet/preprocess_dataset.py b/model_zoo/official/cv/unet/preprocess_dataset.py
index 550ac8e6918..494b348a83f 100644
--- a/model_zoo/official/cv/unet/preprocess_dataset.py
+++ b/model_zoo/official/cv/unet/preprocess_dataset.py
@@ -17,10 +17,9 @@ Preprocess dataset.
 Images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`.
""" import os -import argparse import cv2 import numpy as np -from model_zoo.official.cv.unet.src.config import cfg_unet +from model_zoo.official.cv.unet.src.model_utils.config import config def annToMask(ann, height, width): """Convert annotation to RLE and then to binary mask.""" @@ -107,32 +106,27 @@ def preprocess_coco_dataset(param_dict): cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img) cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask) -def preprocess_dataset(cfg, data_dir): +def preprocess_dataset(data_dir): """Select preprocess function.""" - if cfg['dataset'].lower() == "cell_nuclei": + if config.dataset.lower() == "cell_nuclei": preprocess_cell_nuclei_dataset({"data_dir": data_dir}) - elif cfg['dataset'].lower() == "coco": - if 'split' in cfg and cfg['split'] == 1.0: + elif config.dataset.lower() == "coco": + if config.split == 1.0: train_data_path = os.path.join(data_dir, "train") val_data_path = os.path.join(data_dir, "val") - train_param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"], - "coco_dir": cfg["coco_dir"], "save_dir": train_data_path} + train_param_dict = {"anno_json": config.anno_json, "coco_classes": config.coco_classes, + "coco_dir": config.coco_dir, "save_dir": train_data_path} preprocess_coco_dataset(train_param_dict) - val_param_dict = {"anno_json": cfg["val_anno_json"], "coco_classes": cfg["coco_classes"], - "coco_dir": cfg["val_coco_dir"], "save_dir": val_data_path} + val_param_dict = {"anno_json": config.val_anno_json, "coco_classes": config.coco_classes, + "coco_dir": config.val_coco_dir, "save_dir": val_data_path} preprocess_coco_dataset(val_param_dict) else: - param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"], - "coco_dir": cfg["coco_dir"], "save_dir": data_dir} + param_dict = {"anno_json": config.anno_json, "coco_classes": config.coco_classes, + "coco_dir": config.coco_dir, "save_dir": data_dir} preprocess_coco_dataset(param_dict) else: - raise ValueError("Not support dataset mode {}".format(cfg['dataset'])) + raise ValueError("Not support dataset mode {}".format(config.dataset)) print("========== end preprocess dataset ==========") if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/', - help='save data directory') - args = parser.parse_args() - preprocess_dataset(cfg_unet, args.data_url) + preprocess_dataset(config.data_path) diff --git a/model_zoo/official/cv/unet/scripts/run_distribute_train.sh b/model_zoo/official/cv/unet/scripts/run_distribute_train.sh index e7749935065..0f23873f13c 100644 --- a/model_zoo/official/cv/unet/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/unet/scripts/run_distribute_train.sh @@ -22,13 +22,13 @@ get_real_path() { fi } -if [ $# != 2 ] +if [ $# != 3 ] then echo "==============================================================================================================" - echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" + echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]" echo "Please run the script as: " - echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" - echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data" + echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] 
[DATASET] [CONFIG_PATH]" + echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data /absolute/path/to/config" echo "==============================================================================================================" exit 1 fi @@ -36,6 +36,7 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) export HCCL_CONNECT_TIMEOUT=600 export RANK_SIZE=8 DATASET=$(get_real_path $2) +CONFIG_PATH=$(get_real_path $3) export RANK_TABLE_FILE=$(get_real_path $1) for((i=0;i log.txt 2>&1 & + --data_path=$DATASET \ + --config_path=$CONFIG_PATH \ + --output_path './output' > log.txt 2>&1 & cd ../ done diff --git a/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh b/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh index c91e66c3c37..1e919f09be9 100644 --- a/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh +++ b/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh @@ -22,23 +22,24 @@ get_real_path() { fi } -if [ $# != 2 ] && [ $# != 3 ] +if [ $# != 3 ] && [ $# != 4 ] then echo "==============================================================================================================" echo "Please run the script as: " - echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [DEVICE_ID](option, default is 0)" - echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ 0" + echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](option, default is 0)" + echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/ 0" echo "==============================================================================================================" exit 1 fi PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) export DEVICE_ID=0 -if [ $# != 2 ] +if [ $# != 3 ] then - export DEVICE_ID=$3 + export DEVICE_ID=$4 fi DATASET=$(get_real_path $1) CHECKPOINT=$(get_real_path $2) +CONFIG_PATH=$(get_real_path $3) echo "========== start run evaluation ===========" echo "please get log at eval.log" -python ${PROJECT_DIR}/../eval.py --data_url=$DATASET --ckpt_path=$CHECKPOINT > eval.log 2>&1 & +python ${PROJECT_DIR}/../eval.py --data_path=$DATASET --checkpoint_file_path=$CHECKPOINT --config_path=$CONFIG_PATH > eval.log 2>&1 & diff --git a/model_zoo/official/cv/unet/scripts/run_standalone_train.sh b/model_zoo/official/cv/unet/scripts/run_standalone_train.sh index a7fad97c09e..65b666f9f21 100644 --- a/model_zoo/official/cv/unet/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/unet/scripts/run_standalone_train.sh @@ -22,23 +22,24 @@ get_real_path() { fi } -if [ $# != 1 ] && [ $# != 2 ] +if [ $# != 2 ] && [ $# != 3 ] then echo "==============================================================================================================" echo "Please run the script as: " - echo "bash scripts/run_standalone_train.sh [DATASET] [DEVICE_ID](option, default is 0)" - echo "for example: bash run_standalone_train.sh /path/to/data/ 0" + echo "bash scripts/run_standalone_train.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](option, default is 0)" + echo "for example: bash run_standalone_train.sh /path/to/data/ /path/to/config/ 0" echo "==============================================================================================================" exit 1 fi PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) export DEVICE_ID=0 -if [ $# != 1 ] +if [ $# != 2 ] then - export DEVICE_ID=$2 + export DEVICE_ID=$3 fi DATASET=$(get_real_path $1) +CONFIG_PATH=$(get_real_path $2) 
echo "========== start run training ===========" echo "please get log at train.log" -python ${PROJECT_DIR}/../train.py --data_url=$DATASET > train.log 2>&1 & +python ${PROJECT_DIR}/../train.py --data_path=$DATASET --config_path=$CONFIG_PATH --output './output'> train.log 2>&1 & diff --git a/model_zoo/official/cv/unet/src/config.py b/model_zoo/official/cv/unet/src/config.py deleted file mode 100644 index a40cc331989..00000000000 --- a/model_zoo/official/cv/unet/src/config.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -cfg_unet_medical = { - 'model': 'unet_medical', - 'crop': [388 / 572, 388 / 572], - 'img_size': [572, 572], - 'lr': 0.0001, - 'epochs': 400, - 'repeat': 400, - 'distribute_epochs': 1600, - 'batchsize': 16, - 'cross_valid_ind': 1, - 'num_classes': 2, - 'num_channels': 1, - - 'keep_checkpoint_max': 10, - 'weight_decay': 0.0005, - 'loss_scale': 1024.0, - 'FixedLossScaleManager': 1024.0, - - 'resume': False, - 'resume_ckpt': './', - 'transfer_training': False, - 'filter_weight': ['outc.weight', 'outc.bias'], - 'eval_activate': 'Softmax', - 'eval_resize': False -} - -cfg_unet_nested = { - 'model': 'unet_nested', - 'crop': None, - 'img_size': [576, 576], - 'lr': 0.0001, - 'epochs': 400, - 'repeat': 400, - 'distribute_epochs': 1600, - 'batchsize': 16, - 'cross_valid_ind': 1, - 'num_classes': 2, - 'num_channels': 1, - - 'keep_checkpoint_max': 10, - 'weight_decay': 0.0005, - 'loss_scale': 1024.0, - 'FixedLossScaleManager': 1024.0, - 'use_bn': True, - 'use_ds': True, - 'use_deconv': True, - - 'resume': False, - 'resume_ckpt': './', - 'transfer_training': False, - 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'], - 'eval_activate': 'Softmax', - 'eval_resize': False -} - -cfg_unet_nested_cell = { - 'model': 'unet_nested', - 'dataset': 'Cell_nuclei', - 'crop': None, - 'img_size': [96, 96], - 'lr': 3e-4, - 'epochs': 200, - 'repeat': 10, - 'distribute_epochs': 1600, - 'batchsize': 16, - 'cross_valid_ind': 1, - 'num_classes': 2, - 'num_channels': 3, - - 'keep_checkpoint_max': 10, - 'weight_decay': 0.0005, - 'loss_scale': 1024.0, - 'FixedLossScaleManager': 1024.0, - 'use_bn': True, - 'use_ds': True, - 'use_deconv': True, - - 'resume': False, - 'resume_ckpt': './', - 'transfer_training': False, - 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'], - 'eval_activate': 'Softmax', - 'eval_resize': False -} - -cfg_unet_simple = { - 'model': 'unet_simple', - 'crop': None, - 'img_size': [576, 576], - 'lr': 0.0001, - 'epochs': 400, - 'repeat': 400, - 'distribute_epochs': 2400, - 'batchsize': 16, - 'cross_valid_ind': 1, - 'num_classes': 2, - 'num_channels': 1, - - 'keep_checkpoint_max': 10, - 'weight_decay': 0.0005, - 'loss_scale': 1024.0, - 'FixedLossScaleManager': 1024.0, - - 'resume': False, - 'resume_ckpt': './', - 'transfer_training': False, - 'filter_weight': 
["final.weight"], - 'eval_activate': 'Softmax', - 'eval_resize': False -} - -cfg_unet_simple_coco = { - 'model': 'unet_simple', - 'dataset': 'COCO', - 'split': 1.0, - 'img_size': [512, 512], - 'lr': 3e-4, - 'epochs': 80, - 'repeat': 1, - 'distribute_epochs': 120, - 'cross_valid_ind': 1, - 'batchsize': 16, - 'num_channels': 3, - - 'keep_checkpoint_max': 10, - 'weight_decay': 0.0005, - 'loss_scale': 1024.0, - 'FixedLossScaleManager': 1024.0, - - 'resume': False, - 'resume_ckpt': './', - 'transfer_training': False, - 'filter_weight': ["final.weight"], - 'eval_activate': 'Softmax', - 'eval_resize': False, - - 'num_classes': 81, - 'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', - 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', - 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', - 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', - 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', - 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', - 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', - 'teddy bear', 'hair drier', 'toothbrush'), - # change the following settings to real path - 'anno_json': '/data/coco2017/annotations/instances_train2017.json', - 'val_anno_json': '/data/coco2017/annotations/instances_val2017.json', - 'coco_dir': '/data/coco2017/train2017', - 'val_coco_dir': '/data/coco2017/val2017' -} - -cfg_unet = cfg_unet_simple -if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']: - print("ISBI dataset not support resize to original image size when in evaluation.") - cfg_unet['eval_resize'] = False diff --git a/model_zoo/official/cv/unet/src/data_loader.py b/model_zoo/official/cv/unet/src/data_loader.py index 82921647376..4bc345597b0 100644 --- a/model_zoo/official/cv/unet/src/data_loader.py +++ b/model_zoo/official/cv/unet/src/data_loader.py @@ -122,8 +122,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False) ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False) - if do_crop: - resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))] + if do_crop != "None": + resize_size = [int(img_size[x] * do_crop[x] / 572) for x in range(len(img_size))] else: resize_size = img_size c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR) @@ -146,7 +146,7 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro train_ds = train_ds.map(input_columns="image", operations=c_resize_op) train_ds = train_ds.map(input_columns="mask", operations=c_resize_op) - if do_crop: + if do_crop != "None": train_ds = train_ds.map(input_columns="mask", operations=c_center_crop) post_process = data_post_process train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process) @@ -157,7 +157,7 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro valid_mask_ds = 
ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask) valid_ds = ds.zip((valid_image_ds, valid_mask_ds)) valid_ds = valid_ds.project(columns=["image", "mask"]) - if do_crop: + if do_crop != "None": valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop) post_process = data_post_process valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process) diff --git a/model_zoo/official/cv/unet/src/model_utils/config.py b/model_zoo/official/cv/unet/src/model_utils/config.py new file mode 100644 index 00000000000..e7bddbf924c --- /dev/null +++ b/model_zoo/official/cv/unet/src/model_utils/config.py @@ -0,0 +1,125 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml + +_config_path = "./unet_simple_config.yaml" + +class Config: + """ + Configuration namespace. Convert dictionary to members. + """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="unet_simple_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. 
+ """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + else: + raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml") + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../unet_simple_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser, default, helper, path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() diff --git a/model_zoo/official/nlp/textcnn/utils/device_adapter.py b/model_zoo/official/cv/unet/src/model_utils/device_adapter.py similarity index 77% rename from model_zoo/official/nlp/textcnn/utils/device_adapter.py rename to model_zoo/official/cv/unet/src/model_utils/device_adapter.py index 92439de46b7..9c3d21d5e47 100644 --- a/model_zoo/official/nlp/textcnn/utils/device_adapter.py +++ b/model_zoo/official/cv/unet/src/model_utils/device_adapter.py @@ -15,12 +15,12 @@ """Device adapter for ModelArts""" -from utils.config import config +from src.model_utils.config import config if config.enable_modelarts: - from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id else: - from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id __all__ = [ "get_device_id", "get_device_num", "get_rank_id", "get_job_id" diff --git a/model_zoo/official/nlp/textcnn/utils/local_adapter.py b/model_zoo/official/cv/unet/src/model_utils/local_adapter.py similarity index 100% rename from model_zoo/official/nlp/textcnn/utils/local_adapter.py rename to model_zoo/official/cv/unet/src/model_utils/local_adapter.py diff --git a/model_zoo/official/nlp/textcnn/utils/moxing_adapter.py b/model_zoo/official/cv/unet/src/model_utils/moxing_adapter.py similarity index 98% rename from model_zoo/official/nlp/textcnn/utils/moxing_adapter.py rename to model_zoo/official/cv/unet/src/model_utils/moxing_adapter.py index 420d4808f04..aabd5ac6cf1 100644 --- a/model_zoo/official/nlp/textcnn/utils/moxing_adapter.py +++ b/model_zoo/official/cv/unet/src/model_utils/moxing_adapter.py @@ -18,7 +18,7 @@ import os import functools from mindspore import context -from utils.config import config +from src.model_utils.config import config _global_sync_count = 0 diff --git a/model_zoo/official/cv/unet/src/utils.py b/model_zoo/official/cv/unet/src/utils.py index 72b6956a30c..c03b3230f3d 100644 --- a/model_zoo/official/cv/unet/src/utils.py +++ 
b/model_zoo/official/cv/unet/src/utils.py @@ -21,6 +21,7 @@ from mindspore import nn from mindspore.ops import operations as ops from mindspore.train.callback import Callback from mindspore.common.tensor import Tensor +from src.model_utils.config import config class UnetEval(nn.Cell): """ @@ -63,10 +64,9 @@ def apply_eval(eval_param_dict): class dice_coeff(nn.Metric): """Unet Metric, return dice coefficient and IOU.""" - def __init__(self, cfg_unet, print_res=True): + def __init__(self, print_res=True): super(dice_coeff, self).__init__() self.clear() - self.cfg_unet = cfg_unet self.print_res = print_res def clear(self): @@ -84,20 +84,20 @@ class dice_coeff(nn.Metric): if b != 1: raise ValueError('Batch size should be 1 when in evaluation.') y = y.reshape((h, w, c)) - if self.cfg_unet["eval_activate"].lower() == "softmax": + if config.eval_activate.lower() == "softmax": y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0) - if self.cfg_unet["eval_resize"]: + if config.eval_resize: y_pred = [] - for i in range(self.cfg_unet["num_classes"]): + for i in range(config.num_classes): y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255) y_pred = np.stack(y_pred, axis=-1) else: y_pred = y_softmax - elif self.cfg_unet["eval_activate"].lower() == "argmax": + elif config.eval_activate.lower() == "argmax": y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0) y_pred = [] - for i in range(self.cfg_unet["num_classes"]): - if self.cfg_unet["eval_resize"]: + for i in range(config.num_classes): + if config.eval_resize: y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST)) else: y_pred.append(np.float32(y_argmax == i)) diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py index e0bb07f0ed7..7e7c7d1ecd0 100644 --- a/model_zoo/official/cv/unet/train.py +++ b/model_zoo/official/cv/unet/train.py @@ -14,14 +14,12 @@ # ============================================================================ import os -import argparse import logging -import ast import mindspore import mindspore.nn as nn from mindspore import Model, context -from mindspore.communication.management import init, get_group_size, get_rank +from mindspore.communication.management import init from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -31,133 +29,110 @@ from src.unet_nested import NestedUNet, UNet from src.data_loader import create_dataset, create_multi_class_dataset from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff -from src.config import cfg_unet from src.eval_callback import EvalCallBack -device_id = int(os.getenv('DEVICE_ID')) +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_rank_id, get_device_num + +device_id = int(os.getenv("DEVICE_ID")) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) mindspore.set_seed(1) -def train_net(args_opt, - cross_valid_ind=1, +@moxing_wrapper() +def train_net(cross_valid_ind=1, epochs=400, batch_size=16, - lr=0.0001, - cfg=None): + lr=0.0001): rank = 0 group_size = 1 - data_dir = args_opt.data_url - run_distribute = args_opt.run_distribute + data_dir = 
config.data_path + run_distribute = config.run_distribute if run_distribute: init() - group_size = get_group_size() - rank = get_rank() + group_size = get_device_num() + rank = get_rank_id() parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=False) need_slice = False - if cfg['model'] == 'unet_medical': - net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) - elif cfg['model'] == 'unet_nested': - net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], - use_bn=cfg['use_bn'], use_ds=cfg['use_ds']) - need_slice = cfg['use_ds'] - elif cfg['model'] == 'unet_simple': - net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + if config.model_name == 'unet_medical': + net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes) + elif config.model_name == 'unet_nested': + net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv, + use_bn=config.use_bn, use_ds=config.use_ds) + need_slice = config.use_ds + elif config.model_name == 'unet_simple': + net = UNet(in_channel=config.num_channels, n_class=config.num_classes) else: - raise ValueError("Unsupported model: {}".format(cfg['model'])) + raise ValueError("Unsupported model: {}".format(config.model_name)) - if cfg['resume']: - param_dict = load_checkpoint(cfg['resume_ckpt']) - if cfg['transfer_training']: - filter_checkpoint_parameter_by_list(param_dict, cfg['filter_weight']) + if config.resume: + param_dict = load_checkpoint(config.resume_ckpt) + if config.transfer_training: + filter_checkpoint_parameter_by_list(param_dict, config.filter_weight) load_param_into_net(net, param_dict) - if 'use_ds' in cfg and cfg['use_ds']: + if hasattr(config, "use_ds") and config.use_ds: criterion = MultiCrossEntropyWithLogits() else: criterion = CrossEntropyWithLogits() - if 'dataset' in cfg and cfg['dataset'] != "ISBI": - repeat = cfg['repeat'] if 'repeat' in cfg else 1 - split = cfg['split'] if 'split' in cfg else 0.8 + if hasattr(config, "dataset") and config.dataset != "ISBI": dataset_sink_mode = True per_print_times = 0 - train_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], repeat, batch_size, - num_classes=cfg['num_classes'], is_train=True, augment=True, + repeat = config.repeat if hasattr(config, "repeat") else 1 + split = config.split if hasattr(config, "split") else 0.8 + train_dataset = create_multi_class_dataset(data_dir, config.image_size, repeat, batch_size, + num_classes=config.num_classes, is_train=True, augment=True, split=split, rank=rank, group_size=group_size, shuffle=True) - valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1, - num_classes=cfg['num_classes'], is_train=False, - eval_resize=cfg["eval_resize"], split=split, + valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1, + num_classes=config.num_classes, is_train=False, + eval_resize=config.eval_resize, split=split, python_multiprocessing=False, shuffle=False) else: - repeat = cfg['repeat'] + repeat = config.repeat dataset_sink_mode = False per_print_times = 1 train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, - run_distribute, cfg["crop"], cfg['img_size']) + run_distribute, config.crop, config.image_size) train_data_size = train_dataset.get_dataset_size() print("dataset length is:", train_data_size) + ckpt_save_dir = 
os.path.join(config.output_path, config.checkpoint_path)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
-                                   keep_checkpoint_max=cfg['keep_checkpoint_max'])
-    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
-                                 directory='./ckpt_{}/'.format(device_id),
+                                   keep_checkpoint_max=config.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name),
+                                 directory=os.path.join(ckpt_save_dir, 'ckpt_{}'.format(device_id)),
                                  config=ckpt_config)
 
-    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
-                        loss_scale=cfg['loss_scale'])
+    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay,
+                        loss_scale=config.loss_scale)
 
-    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
+    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.FixedLossScaleManager, False)
     model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
 
     print("============== Starting Training ==============")
     callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
-    if args_opt.run_eval:
+    if config.run_eval:
         eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
-                           metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
-        eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics}
-        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
-                               eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
-                               ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
-                               metrics_name=args_opt.eval_metrics)
+                           metrics={"dice_coeff": dice_coeff(False)})
+        eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
+        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
+                               eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
+                               ckpt_directory=os.path.join(ckpt_save_dir, 'ckpt_{}'.format(device_id)),
+                               besk_ckpt_name="best.ckpt", metrics_name=config.eval_metrics)
         callbacks.append(eval_cb)
     model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
     print("============== End Training ==============")
 
 
-def get_args():
-    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
-                        help='data directory')
-    parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
-                        default=False, help='Run distribute, default: false.')
-    parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
-                        help="Run evaluation when training, default is False.")
-    parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
-                        help="Save best checkpoint when run_eval is True, default is True.")
-    parser.add_argument("--eval_start_epoch", type=int, default=0,
-                        help="Evaluation start epoch when run_eval is True, default is 0.")
-    parser.add_argument("--eval_interval", type=int, default=1,
-                        help="Evaluation interval when run_eval is True, default is 1.")
-    parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
-                        help="Evaluation metrics when run_eval is 
True, support [dice_coeff, iou], " - "default is dice_coeff.") - - return parser.parse_args() - - if __name__ == '__main__': logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') - args = get_args() - print("Training setting:", args) - epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs'] - train_net(args_opt=args, - cross_valid_ind=cfg_unet['cross_valid_ind'], + epoch_size = config.epochs if not config.run_distribute else config.distribute_epochs + train_net(cross_valid_ind=config.cross_valid_ind, epochs=epoch_size, - batch_size=cfg_unet['batchsize'], - lr=cfg_unet['lr'], - cfg=cfg_unet) + batch_size=config.batch_size, + lr=config.lr) diff --git a/model_zoo/official/cv/unet/unet_medical_config.yaml b/model_zoo/official/cv/unet/unet_medical_config.yaml new file mode 100644 index 00000000000..d1a4290f834 --- /dev/null +++ b/model_zoo/official/cv/unet/unet_medical_config.yaml @@ -0,0 +1,67 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +enable_profiling: False + +# ============================================================================== +# Training options +model_name: 'unet_medical' +run_eval: False +run_distribute: False +crop: [388, 388] +image_size : [572, 572] +lr: 0.0001 +epochs: 400 +repeat: 400 +distribute_epochs: 1600 +batch_size: 16 +cross_valid_ind: 1 +num_classes: 2 +num_channels: 1 +weight_decay: 0.0005 +loss_scale: 1024.0 +FixedLossScaleManager: 1024.0 +resume: False +resume_ckpt: './' +transfer_training: False +filter_weight: ['outc.weight', 'outc.bias'] + +#Eval options +keep_checkpoint_max: 10 +eval_activate: 'Softmax' +eval_resize: False +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'ckpt_unet_medical_adam-4-75.ckpt' +rst_path: './result_Files/' + +# Export options +width: 572 +height: 572 +file_name: unet +file_format: AIR + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +weight_decay: "Weight decay." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +checkpoint_path: "The location of the checkpoint file." +checkpoint_file_path: "The location of the checkpoint file." 
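Because `parse_cli_to_yaml` registers one command-line flag per scalar key in the yaml above, any of these values can be overridden at launch without editing the file. For example (values are hypothetical):

```bash
python train.py --config_path=./unet_medical_config.yaml \
    --data_path=/path/to/isbi/ --lr=0.0003 --epochs=200 --batch_size=8
```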
diff --git a/model_zoo/official/cv/unet/unet_nested_cell_config.yaml b/model_zoo/official/cv/unet/unet_nested_cell_config.yaml new file mode 100644 index 00000000000..87e3f4d4655 --- /dev/null +++ b/model_zoo/official/cv/unet/unet_nested_cell_config.yaml @@ -0,0 +1,71 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +enable_profiling: False + +# ============================================================================== +# Training options +model_name: 'unet_nested' +run_eval: False +run_distribute: False +dataset: 'Cell_nuclei' +crop: None +image_size : [96, 96] +lr: 0.0003 +epochs: 200 +repeat: 10 +distribute_epochs: 1600 +batch_size: 16 +cross_valid_ind: 1 +num_classes: 2 +num_channels: 3 +weight_decay: 0.0005 +loss_scale: 1024.0 +FixedLossScaleManager: 1024.0 +use_ds: False +use_bn: False +use_deconv: True +resume: False +resume_ckpt: './' +transfer_training: False +filter_weight: ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] + +#Eval options +keep_checkpoint_max: 10 +eval_activate: 'Softmax' +eval_resize: False +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt' +rst_path: './result_Files/' + +# Export options +width: 572 +height: 572 +file_name: unet +file_format: AIR + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +weight_decay: "Weight decay." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +checkpoint_path: "The location of the checkpoint file." +checkpoint_file_path: "The location of the checkpoint file." 
diff --git a/model_zoo/official/cv/unet/unet_nested_config.yaml b/model_zoo/official/cv/unet/unet_nested_config.yaml new file mode 100644 index 00000000000..859c47b9a19 --- /dev/null +++ b/model_zoo/official/cv/unet/unet_nested_config.yaml @@ -0,0 +1,71 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +enable_profiling: False + +# ============================================================================== +# Training options +model_name: 'unet_nested' +run_eval: False +run_distribute: False +crop: None +image_size : [576, 576] +lr: 0.0001 +epochs: 400 +repeat: 400 +distribute_epochs: 1600 +batch_size: 16 +cross_valid_ind: 1 +num_classes: 2 +num_channels: 1 +weight_decay: 0.0005 +loss_scale: 1024.0 +FixedLossScaleManager: 1024.0 +use_ds: True +use_bn: True +use_deconv: True +resume: False +resume_ckpt: './' +transfer_training: False +filter_weight: ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] + +#Eval options +keep_checkpoint_max: 10 +eval_activate: 'Softmax' +eval_resize: False +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt' +rst_path: './result_Files/' + +# Export options +width: 572 +height: 572 +file_name: unet +file_format: AIR + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +weight_decay: "Weight decay." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +checkpoint_path: "The location of the checkpoint file." +checkpoint_file_path: "The location of the checkpoint file." 
diff --git a/model_zoo/official/cv/unet/unet_simple_coco_config.yaml b/model_zoo/official/cv/unet/unet_simple_coco_config.yaml new file mode 100644 index 00000000000..e17e3e62f62 --- /dev/null +++ b/model_zoo/official/cv/unet/unet_simple_coco_config.yaml @@ -0,0 +1,91 @@ +# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path/"
+device_target: 'Ascend'
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+model_name: 'unet_simple'
+run_eval: False
+run_distribute: False
+dataset: 'COCO'
+split: 1.0
+image_size: [512, 512]
+lr: 0.0001
+epochs: 80
+repeat: 1
+distribute_epochs: 120
+batch_size: 16
+cross_valid_ind: 1
+num_classes: 81
+num_channels: 3
+weight_decay: 0.0005
+loss_scale: 1024.0
+FixedLossScaleManager: 1024.0
+resume: False
+resume_ckpt: './'
+transfer_training: False
+filter_weight: ['final1.weight']
+coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
+               'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
+               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+               'kite', 'baseball bat', 'baseball glove', 'skateboard',
+               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
+               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
+               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
+               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+               'refrigerator', 'book', 'clock', 'vase', 'scissors',
+               'teddy bear', 'hair drier', 'toothbrush']
+
+anno_json: '/data/coco2017/annotations/instances_train2017.json'
+val_anno_json: '/data/coco2017/annotations/instances_val2017.json'
+coco_dir: '/data/coco2017/train2017'
+val_coco_dir: '/data/coco2017/val2017'
+
+# Eval options
+eval_metrics: "dice_coeff"
+eval_start_epoch: 0
+eval_interval: 1
+keep_checkpoint_max: 10
+eval_activate: 'Softmax'
+eval_resize: False
+checkpoint_path: './checkpoint/'
+checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
+rst_path: './result_Files/'
+
+# Export options
+width: 572
+height: 572
+file_name: unet
+file_format: AIR
+
+---
+# Help description for each configuration
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+checkpoint_url: 'The location of checkpoint for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+load_path: 'The location of checkpoint for local'
+device_target: 'Target device type, available: [Ascend, GPU, CPU]'
+enable_profiling: 'Whether to enable profiling while training, default: False'
+num_classes: 'Number of classes in the dataset'
+batch_size: "Batch size for training and evaluation"
+weight_decay: "Weight decay."
+keep_checkpoint_max: "Keep only the last keep_checkpoint_max checkpoints"
+checkpoint_path: "The directory where checkpoints are saved."
+checkpoint_file_path: "The name of the checkpoint file."
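One thing to watch when editing the COCO config above: `num_classes` (81) has to stay in sync with `coco_classes`, which lists the 80 COCO categories plus `background`. A quick self-check, assuming PyYAML:

```python
import yaml

with open("unet_simple_coco_config.yaml") as f:
    cfg = next(yaml.safe_load_all(f))  # first document holds the values
assert len(cfg["coco_classes"]) == cfg["num_classes"] == 81
```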
diff --git a/model_zoo/official/cv/unet/unet_simple_config.yaml b/model_zoo/official/cv/unet/unet_simple_config.yaml new file mode 100644 index 00000000000..4ba5db3eafc --- /dev/null +++ b/model_zoo/official/cv/unet/unet_simple_config.yaml @@ -0,0 +1,68 @@ +# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path/"
+device_target: 'Ascend'
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+model_name: 'unet_simple'
+run_eval: False
+run_distribute: False
+crop: None
+image_size: [576, 576]
+lr: 0.0001
+epochs: 400
+repeat: 400
+distribute_epochs: 2400
+batch_size: 16
+cross_valid_ind: 1
+num_classes: 2
+num_channels: 1
+weight_decay: 0.0005
+loss_scale: 1024.0
+FixedLossScaleManager: 1024.0
+resume: False
+resume_ckpt: './'
+transfer_training: False
+filter_weight: ['final1.weight']
+
+# Eval options
+keep_checkpoint_max: 10
+eval_activate: 'Softmax'
+eval_resize: False
+checkpoint_path: './checkpoint/'
+checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
+rst_path: './result_Files/'
+
+# Export options
+width: 572
+height: 572
+file_name: unet
+file_format: AIR
+
+---
+# Help description for each configuration
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+checkpoint_url: 'The location of checkpoint for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+load_path: 'The location of checkpoint for local'
+device_target: 'Target device type, available: [Ascend, GPU, CPU]'
+enable_profiling: 'Whether to enable profiling while training, default: False'
+num_classes: 'Number of classes in the dataset'
+batch_size: "Batch size for training and evaluation"
+weight_decay: "Weight decay."
+keep_checkpoint_max: "Keep only the last keep_checkpoint_max checkpoints"
+checkpoint_path: "The directory where checkpoints are saved."
+checkpoint_file_path: "The name of the checkpoint file."
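All four configs carry `loss_scale` / `FixedLossScaleManager: 1024.0`. The unet `train.py` is not part of this hunk, so the exact wiring is an assumption, but the value plausibly constructs MindSpore's fixed loss-scale manager when the training `Model` is assembled:

```python
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# Assumed wiring: the yaml value becomes the fixed scale, typically passed as
# Model(..., loss_scale_manager=manager) when the training model is built.
manager = FixedLossScaleManager(1024.0, drop_overflow_update=False)
print(manager.get_loss_scale())  # 1024.0
```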
diff --git a/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh b/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh index d3a886a2344..443d5179030 100644 --- a/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/unet3d/scripts/run_distribute_train.sh @@ -16,7 +16,7 @@
 if [ $# -ne 2 ]
 then
-    echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
+    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]"
     exit 1
 fi
diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh index f377be58ddd..dfb301f354f 100644 --- a/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval.sh @@ -18,14 +18,14 @@
 if [ $# != 2 ]
 then
     echo "=============================================================================================================="
     echo "Please run the script as: "
-    echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
+    echo "bash scripts/run_standalone_eval.sh [DATA_PATH] [CHECKPOINT]"
     echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
     echo "=============================================================================================================="
 fi
 
 if [ $# != 2 ]
 then
-    echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
+    echo "Usage: sh run_standalone_eval.sh [DATA_PATH] [CHECKPOINT_PATH]"
     exit 1
 fi
diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh index 1e30cea5c7c..a04b14d08c0 100644 --- a/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_train.sh @@ -16,7 +16,7 @@
 if [ $# -ne 1 ]
 then
-    echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
+    echo "Usage: sh run_standalone_train.sh [DATA_PATH]"
     exit 1
 fi
diff --git a/model_zoo/official/nlp/textcnn/README.md b/model_zoo/official/nlp/textcnn/README.md index 6f35b29dd1e..8e7997def91 100644 --- a/model_zoo/official/nlp/textcnn/README.md +++ b/model_zoo/official/nlp/textcnn/README.md @@ -100,26 +100,26 @@ If you want to run in modelarts, please check the official documentation of [mod
 ```bash
 ├── model_zoo
- ├── README.md // descriptions about all the models
+ ├── README.md // descriptions about all the models
 ├── textcnn
- ├── README.md // descriptions about textcnn
+ ├── README.md // descriptions about textcnn
 ├──scripts
- │ ├── run_train.sh // shell script for distributed on Ascend
- │ ├── run_eval.sh // shell script for evaluation on Ascend
+ │ ├── run_train.sh // shell script for distributed on Ascend
+ │ ├── run_eval.sh // shell script for evaluation on Ascend
 ├── src
- │ ├── dataset.py // Processing dataset
- │ ├── textcnn.py // textcnn architecture
- ├── utils
- │ ├──device_adapter.py // device adapter
- │ ├──local_adapter.py // local adapter
- │ ├──moxing_adapter.py // moxing adapter
- │ ├── config.py // parameter analysis
+ │ ├── dataset.py // Processing dataset
+ │ ├── textcnn.py // textcnn architecture
+ │ ├── model_utils
+ │ │ ├──device_adapter.py // device adapter
+ │ │ ├──local_adapter.py // local adapter
+ │ │ ├──moxing_adapter.py // moxing adapter
+ │ │ ├──config.py // parameter analysis
 ├── mr_config.yaml // parameter configuration
- ├── sst2_config.yaml // parameter configuration
- ├── subj_config.yaml // parameter configuration
- ├── 
train.py // training script - ├── eval.py // evaluation script - ├── export.py // export checkpoint to other format file + ├── sst2_config.yaml // parameter configuration + ├── subj_config.yaml // parameter configuration + ├── train.py // training script + ├── eval.py // evaluation script + ├── export.py // export checkpoint to other format file ``` ## [Script Parameters](#contents) diff --git a/model_zoo/official/nlp/textcnn/eval.py b/model_zoo/official/nlp/textcnn/eval.py index c1fc4d644ad..9b176d3fc61 100644 --- a/model_zoo/official/nlp/textcnn/eval.py +++ b/model_zoo/official/nlp/textcnn/eval.py @@ -22,9 +22,9 @@ from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from utils.moxing_adapter import moxing_wrapper -from utils.device_adapter import get_device_id -from utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id +from src.model_utils.config import config from src.textcnn import TextCNN from src.dataset import MovieReview, SST2, Subjectivity diff --git a/model_zoo/official/nlp/textcnn/export.py b/model_zoo/official/nlp/textcnn/export.py index fe75e83e5ad..99494af5a18 100644 --- a/model_zoo/official/nlp/textcnn/export.py +++ b/model_zoo/official/nlp/textcnn/export.py @@ -20,7 +20,7 @@ import numpy as np from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context -from utils.config import config +from src.model_utils.config import config from src.textcnn import TextCNN from src.dataset import MovieReview, SST2, Subjectivity diff --git a/model_zoo/official/nlp/textcnn/mr_config.yaml b/model_zoo/official/nlp/textcnn/mr_config.yaml index 575418404f8..c316b68650d 100644 --- a/model_zoo/official/nlp/textcnn/mr_config.yaml +++ b/model_zoo/official/nlp/textcnn/mr_config.yaml @@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt' word_len: 51 vec_length: 40 base_lr: 1e-3 +label_dir: "" +result_dir: "" +result_path: './preprocess_Result/' # Export options device_id: 0 diff --git a/model_zoo/official/nlp/textcnn/postprocess.py b/model_zoo/official/nlp/textcnn/postprocess.py index e14e490e85e..6ba4f891a52 100644 --- a/model_zoo/official/nlp/textcnn/postprocess.py +++ b/model_zoo/official/nlp/textcnn/postprocess.py @@ -16,35 +16,21 @@ ##############postprocess################# """ import os -import argparse import numpy as np from mindspore.nn.metrics import Accuracy -from src.config import cfg_mr, cfg_subj, cfg_sst2 - - -parser = argparse.ArgumentParser(description='postprocess') -parser.add_argument('--label_dir', type=str, default="", help='label data dir') -parser.add_argument('--result_dir', type=str, default="", help="infer result dir") -parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2']) -args = parser.parse_args() +from src.model_utils.config import config if __name__ == '__main__': - if args.dataset == 'MR': - cfg = cfg_mr - elif args.dataset == 'SUBJ': - cfg = cfg_subj - elif args.dataset == 'SST2': - cfg = cfg_sst2 - file_prefix = 'textcnn_bs' + str(cfg.batch_size) + '_' + file_prefix = 'textcnn_bs' + str(config.batch_size) + '_' metric = Accuracy() metric.clear() - label_list = np.load(args.label_dir, allow_pickle=True) + label_list = np.load(config.label_dir, allow_pickle=True) for idx, label in enumerate(label_list): - pred = np.fromfile(os.path.join(args.result_dir, file_prefix + str(idx) + '_0.bin'), np.float32) - pred = 
pred.reshape(cfg.batch_size, int(pred.shape[0]/cfg.batch_size)) + pred = np.fromfile(os.path.join(config.result_dir, file_prefix + str(idx) + '_0.bin'), np.float32) + pred = pred.reshape(config.batch_size, int(pred.shape[0]/config.batch_size)) metric.update(pred, label) accuracy = metric.eval() print("accuracy: ", accuracy) diff --git a/model_zoo/official/nlp/textcnn/preprocess.py b/model_zoo/official/nlp/textcnn/preprocess.py index 4bc2374be4c..c008e06787d 100644 --- a/model_zoo/official/nlp/textcnn/preprocess.py +++ b/model_zoo/official/nlp/textcnn/preprocess.py @@ -15,37 +15,28 @@ """ ##############preprocess textcnn example on movie review################# """ -import argparse import os import numpy as np -from src.config import cfg_mr, cfg_subj, cfg_sst2 +from src.model_utils.config import config from src.dataset import MovieReview, SST2, Subjectivity -parser = argparse.ArgumentParser(description='TextCNN') -parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2']) -parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path') -args_opt = parser.parse_args() - if __name__ == '__main__': - if args_opt.dataset == 'MR': - cfg = cfg_mr - instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args_opt.dataset == 'SUBJ': - cfg = cfg_subj - instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args_opt.dataset == 'SST2': - cfg = cfg_sst2 - instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) + if config.dataset == 'MR': + instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SUBJ': + instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SST2': + instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9) - dataset = instance.create_test_dataset(batch_size=cfg.batch_size) - img_path = os.path.join(args_opt.result_path, "00_data") + dataset = instance.create_test_dataset(batch_size=config.batch_size) + img_path = os.path.join(config.result_path, "00_data") os.makedirs(img_path) label_list = [] for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): - file_name = "textcnn_bs" + str(cfg.batch_size) + "_" + str(i) + ".bin" + file_name = "textcnn_bs" + str(config.batch_size) + "_" + str(i) + ".bin" file_path = img_path + "/" + file_name data['data'].tofile(file_path) label_list.append(data['label']) - np.save(args_opt.result_path + "label_ids.npy", label_list) + np.save(config.result_path + "label_ids.npy", label_list) print("="*20, "export bin files finished", "="*20) diff --git a/model_zoo/official/nlp/textcnn/utils/config.py b/model_zoo/official/nlp/textcnn/src/model_utils/config.py similarity index 100% rename from model_zoo/official/nlp/textcnn/utils/config.py rename to model_zoo/official/nlp/textcnn/src/model_utils/config.py diff --git a/model_zoo/official/nlp/textcnn/src/model_utils/device_adapter.py b/model_zoo/official/nlp/textcnn/src/model_utils/device_adapter.py new file mode 100644 index 00000000000..9c3d21d5e47 --- /dev/null +++ b/model_zoo/official/nlp/textcnn/src/model_utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from src.model_utils.config import config + +if config.enable_modelarts: + from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/official/nlp/textcnn/src/model_utils/local_adapter.py b/model_zoo/official/nlp/textcnn/src/model_utils/local_adapter.py new file mode 100644 index 00000000000..769fa6dc78e --- /dev/null +++ b/model_zoo/official/nlp/textcnn/src/model_utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/model_zoo/official/nlp/textcnn/src/model_utils/moxing_adapter.py b/model_zoo/official/nlp/textcnn/src/model_utils/moxing_adapter.py new file mode 100644 index 00000000000..aabd5ac6cf1 --- /dev/null +++ b/model_zoo/official/nlp/textcnn/src/model_utils/moxing_adapter.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+from mindspore import context
+from src.model_utils.config import config
+
+_global_sync_count = 0
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id else "default"
+    return job_id
+
+def sync_data(from_path, to_path):
+    """
+    Download from a remote OBS url to a local directory when from_path is remote
+    and to_path is local; upload from local to remote OBS when reversed.
+    """
+    import moxing as mox  # imported lazily: moxing is only available inside ModelArts jobs
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains at most 8 devices; only one device per server does the copy.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    # The other devices block here until the copying device creates the lock file.
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download dataset and upload outputs.
+    """
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print("Dataset downloaded: ", os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print("Preload downloaded: ", os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print("Workspace downloaded: ", os.listdir(config.output_path))
+
+                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            run_func(*args, **kwargs)
+
+            # Upload outputs to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print("Start to copy output directory")
+                    sync_data(config.output_path, config.train_url)
+        return wrapped_func
+    return wrapper
diff --git a/model_zoo/official/nlp/textcnn/sst2_config.yaml b/model_zoo/official/nlp/textcnn/sst2_config.yaml index a5d3f72bb70..2dc19b53754 100644 --- a/model_zoo/official/nlp/textcnn/sst2_config.yaml +++ b/model_zoo/official/nlp/textcnn/sst2_config.yaml @@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
 word_len: 51
 vec_length: 40
 base_lr: 5e-3
+label_dir: ""
+result_dir: ""
+result_path: './preprocess_Result/'
 
 # Export options
 device_id: 0
diff --git a/model_zoo/official/nlp/textcnn/subj_config.yaml b/model_zoo/official/nlp/textcnn/subj_config.yaml index 0a81f9729b8..f1615aae429 100644 --- a/model_zoo/official/nlp/textcnn/subj_config.yaml +++ b/model_zoo/official/nlp/textcnn/subj_config.yaml @@ -25,6 +25,9 @@ 
checkpoint_file_path: 'train_textcnn-4_149.ckpt' word_len: 51 vec_length: 40 base_lr: 8e-4 +label_dir: "" +result_dir: "" +result_path: './preprocess_Result/' # Export options device_id: 0 diff --git a/model_zoo/official/nlp/textcnn/train.py b/model_zoo/official/nlp/textcnn/train.py index 0ff4a289879..8064a76a478 100644 --- a/model_zoo/official/nlp/textcnn/train.py +++ b/model_zoo/official/nlp/textcnn/train.py @@ -26,9 +26,9 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from utils.moxing_adapter import moxing_wrapper -from utils.device_adapter import get_device_id, get_rank_id -from utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id, get_rank_id +from src.model_utils.config import config from src.textcnn import TextCNN from src.textcnn import SoftmaxCrossEntropyExpand from src.dataset import MovieReview, SST2, Subjectivity
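With the reworked entry points above, every script follows the same pattern: read hyper-parameters from the merged yaml/CLI `config` namespace and let `moxing_wrapper` handle the OBS transfers. A hedged usage sketch (`train_textcnn` and `sync_done` are hypothetical stand-ins, not functions from this diff; the decorator and config module are the ones introduced here):

```python
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def sync_done():
    # optional hook; runs only on ModelArts, after inputs were pulled from OBS
    print("inputs ready under", config.data_path)

@moxing_wrapper(pre_process=sync_done)
def train_textcnn():
    print("training with batch_size =", config.batch_size)

if __name__ == "__main__":
    train_textcnn()  # on ModelArts, output_path is synced back to train_url afterwards
```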