!16400 modify unet for cloud and modify model_utils of textcnn

From: @Somnus2020
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-05-15 16:15:36 +08:00 committed by Gitee
commit 844afd374b
39 changed files with 969 additions and 485 deletions

View File

@ -160,6 +160,35 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
Then you can run everything just like on ascend. Then you can run everything just like on ascend.
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
```python
# run distributed training on modelarts example
# (1) First, Perform a or b.
# a. Set "enable_modelarts=True" on yaml file.
# Set other parameters on yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add other parameters on the website UI interface.
# (2) Set the config directory to "config_path=/The path of config in S3/"
# (3) Set the code directory to "/path/unet" on the website UI interface.
# (4) Set the startup file to "train.py" on the website UI interface.
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (6) Create your job.
# run evaluation on modelarts example
# (1) Copy or upload your trained model to S3 bucket.
# (2) Perform a or b.
# a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
# b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
# (3) Set the config directory to "config_path=/The path of config in S3/"
# (4) Set the code directory to "/path/unet" on the website UI interface.
# (5) Set the startup file to "eval.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
```
## [Script Description](#contents) ## [Script Description](#contents)
### [Script and Sample Code](#contents) ### [Script and Sample Code](#contents)
@ -190,6 +219,16 @@ Then you can run everything just like on ascend.
├──__init__.py // init file ├──__init__.py // init file
├──unet_model.py // unet model ├──unet_model.py // unet model
├──unet_parts.py // unet part ├──unet_parts.py // unet part
├── model_utils
│ ├── config.py // parameter configuration
│ ├── device_adapter.py // device adapter
│ ├── local_adapter.py // local adapter
│ ├── moxing_adapter.py // moxing adapter
├── unet_medical_config.yaml // parameter configuration
├── unet_nested_cell_config.yaml // parameter configuration
├── unet_nested_config.yaml // parameter configuration
├── unet_simple_config.yaml // parameter configuration
├── unet_simple_coco_config.yaml // parameter configuration
├── train.py // training script ├── train.py // training script
├── eval.py // evaluation script ├── eval.py // evaluation script
├── export.py // export script ├── export.py // export script

View File

@ -164,6 +164,38 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
然后在容器里的操作就和Ascend平台上是一样的。 然后在容器里的操作就和Ascend平台上是一样的。
如果要在modelarts上进行模型的训练可以参考modelarts的官方指导文档(https://support.huaweicloud.com/modelarts/)
开始进行模型的训练和推理,具体操作如下:
```python
# 在modelarts上使用分布式训练的示例
# (1) 选址a或者b其中一种方式。
# a. 设置 "enable_modelarts=True" 。
# 在yaml文件上设置网络所需的参数。
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
# 在modelarts的界面上设置网络所需的参数。
# (2)设置网络配置文件的路径 "config_path=/The path of config in S3/"
# (3) 在modelarts的界面上设置代码的路径 "/path/unet"。
# (4) 在modelarts的界面上设置模型的启动文件 "train.py" 。
# (5) 在modelarts的界面上设置模型的数据路径 "Dataset path" ,
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
# (6) 开始模型的训练。
# 在modelarts上使用模型推理的示例
# (1) 把训练好的模型地方到桶的对应位置。
# (2) 选址a或者b其中一种方式。
# a. 设置 "checkpoint_file_path='/cache/checkpoint_path/model.ckpt" 在 yaml 文件.
# 设置 "checkpoint_url=/The path of checkpoint in S3/" 在 yaml 文件.
# b. 增加 "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" 参数在modearts的界面上。
# 增加 "checkpoint_url=/The path of checkpoint in S3/" 参数在modearts的界面上。
# (3) 设置网络配置文件的路径 "config_path=/The path of config in S3/"
# (4) 在modelarts的界面上设置代码的路径 "/path/unet"。
# (5) 在modelarts的界面上设置模型的启动文件 "eval.py" 。
# (6) 在modelarts的界面上设置模型的数据路径 "Dataset path" ,
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
# (7) 开始模型的推理。
```
## 脚本说明 ## 脚本说明
### 脚本及样例代码 ### 脚本及样例代码
@ -194,6 +226,16 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
├──__init__.py ├──__init__.py
├──unet_model.py // Unet++ 网络结构 ├──unet_model.py // Unet++ 网络结构
├──unet_parts.py // Unet++ 子网 ├──unet_parts.py // Unet++ 子网
├── model_utils
│ ├── config.py // 参数配置
│ ├── device_adapter.py // 设备配置
│ ├── local_adapter.py // 本地设备配置
│ ├── moxing_adapter.py // modelarts设备配置
├── unet_medical_config.yaml // 配置文件
├── unet_nested_cell_config.yaml // 配置文件
├── unet_nested_config.yaml // 配置文件
├── unet_simple_config.yaml // 配置文件
├── unet_simple_coco_config.yaml // 配置文件
├── train.py // 训练脚本 ├── train.py // 训练脚本
├── eval.py // 推理脚本 ├── eval.py // 推理脚本
├── export.py // 导出脚本 ├── export.py // 导出脚本

View File

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
import os import os
import argparse
import logging import logging
from mindspore import context, Model from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -22,38 +21,39 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.data_loader import create_dataset, create_multi_class_dataset from src.data_loader import create_dataset, create_multi_class_dataset
from src.unet_medical import UNetMedical from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet
from src.utils import UnetEval, TempLoss, dice_coeff from src.utils import UnetEval, TempLoss, dice_coeff
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv("DEVICE_ID"))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
@moxing_wrapper()
def test_net(data_dir, def test_net(data_dir,
ckpt_path, ckpt_path,
cross_valid_ind=1, cross_valid_ind=1):
cfg=None): if config.model_name == 'unet_medical':
if cfg['model'] == 'unet_medical': net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) elif config.model_name == 'unet_nested':
elif cfg['model'] == 'unet_nested': net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], use_bn=config.use_bn, use_ds=False)
use_bn=cfg['use_bn'], use_ds=False) elif config.model_name == 'unet_simple':
elif cfg['model'] == 'unet_simple': net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else: else:
raise ValueError("Unsupported model: {}".format(cfg['model'])) raise ValueError("Unsupported model: {}".format(config.model_name))
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net = UnetEval(net) net = UnetEval(net)
if 'dataset' in cfg and cfg['dataset'] != "ISBI": if hasattr(config, "dataset") and config.dataset != "ISBI":
split = cfg['split'] if 'split' in cfg else 0.8 split = config.split if hasattr(config, "dataset") else 0.8
valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1, valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
num_classes=cfg['num_classes'], is_train=False, num_classes=config.num_classes, is_train=False,
eval_resize=cfg["eval_resize"], split=split, eval_resize=config.eval_resize, split=split,
python_multiprocessing=False, shuffle=False) python_multiprocessing=False, shuffle=False)
else: else:
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size']) do_crop=config.crop, img_size=config.image_size)
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)}) model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
print("============== Starting Evaluating ============") print("============== Starting Evaluating ============")
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"] eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
@ -61,22 +61,8 @@ def test_net(data_dir,
print("============== Cross valid IOU is:", eval_score[1]) print("============== Cross valid IOU is:", eval_score[1])
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-p', '--ckpt_path', dest='ckpt_path', type=str, default='ckpt_unet_medical_adam-1_600.ckpt',
help='checkpoint path')
return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args() test_net(data_dir=config.data_path,
print("Testing setting:", args) ckpt_path=config.checkpoint_file_path,
test_net(data_dir=args.data_url, cross_valid_ind=config.cross_valid_ind)
ckpt_path=args.ckpt_path,
cross_valid_ind=cfg_unet['cross_valid_ind'],
cfg=cfg_unet)

View File

@ -13,46 +13,36 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import argparse
import numpy as np import numpy as np
from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context
from src.unet_medical.unet_model import UNetMedical from src.unet_medical.unet_model import UNetMedical
from src.unet_nested import NestedUNet, UNet from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet as cfg
from src.utils import UnetEval from src.utils import UnetEval
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
parser = argparse.ArgumentParser(description='unet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument('--width', type=int, default=572, help='input width')
parser.add_argument('--height', type=int, default=572, help='input height')
parser.add_argument("--file_name", type=str, default="unet", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if args.device_target == "Ascend": if config.device_target == "Ascend":
context.set_context(device_id=args.device_id) context.set_context(device_id=get_device_id())
if __name__ == "__main__": if __name__ == "__main__":
if cfg['model'] == 'unet_medical': if config.model == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
elif cfg['model'] == 'unet_nested': elif config.model == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
use_bn=cfg['use_bn'], use_ds=False) use_bn=config.use_bn, use_ds=False)
elif cfg['model'] == 'unet_simple': elif config.model == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
else: else:
raise ValueError("Unsupported model: {}".format(cfg['model'])) raise ValueError("Unsupported model: {}".format(config.model))
# return a parameter dict for model # return a parameter dict for model
param_dict = load_checkpoint(args.ckpt_file) param_dict = load_checkpoint(config.checkpoint_file_path)
# load the parameter into net # load the parameter into net
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net = UnetEval(net) net = UnetEval(net)
input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32)) input_data = Tensor(np.ones([config.batch_size, config.num_channels, config.height, \
export(net, input_data, file_name=args.file_name, file_format=args.file_format) config.width]).astype(np.float32))
export(net, input_data, file_name=config.file_name, file_format=config.file_format)

View File

@ -14,11 +14,10 @@
# ============================================================================ # ============================================================================
"""unet 310 infer.""" """unet 310 infer."""
import os import os
import argparse
import cv2 import cv2
import numpy as np import numpy as np
from src.config import cfg_unet from src.model_utils.config import config
class dice_coeff(): class dice_coeff():
def __init__(self): def __init__(self):
@ -38,20 +37,20 @@ class dice_coeff():
if b != 1: if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.') raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c)) y = y.reshape((h, w, c))
if cfg_unet["eval_activate"].lower() == "softmax": if config.eval_activate.lower() == "softmax":
y_softmax = np.squeeze(inputs[0][0], axis=0) y_softmax = np.squeeze(inputs[0][0], axis=0)
if cfg_unet["eval_resize"]: if config.eval_resize:
y_pred = [] y_pred = []
for m in range(cfg_unet["num_classes"]): for m in range(config.num_classes):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255) y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1) y_pred = np.stack(y_pred, axis=-1)
else: else:
y_pred = y_softmax y_pred = y_softmax
elif cfg_unet["eval_activate"].lower() == "argmax": elif config.eval_activate.lower() == "argmax":
y_argmax = np.squeeze(inputs[0][1], axis=0) y_argmax = np.squeeze(inputs[0][1], axis=0)
y_pred = [] y_pred = []
for n in range(cfg_unet["num_classes"]): for n in range(config.num_classes):
if cfg_unet["eval_resize"]: if config.eval_resize:
y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST)) y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
else: else:
y_pred.append(np.float32(y_argmax == n)) y_pred.append(np.float32(y_argmax == n))
@ -73,25 +72,13 @@ class dice_coeff():
raise RuntimeError('Total samples num must not be 0.') raise RuntimeError('Total samples num must not be 0.')
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
help='infer result path')
return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
args = get_args() rst_path = config.rst_path
rst_path = args.rst_path
metrics = dice_coeff() metrics = dice_coeff()
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei": if config.dataset == "Cell_nuclei":
img_size = tuple(cfg_unet['img_size']) img_size = tuple(config.img_size)
for i, bin_name in enumerate(os.listdir('./preprocess_Result/')): for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
f = bin_name.replace(".png", "") f = bin_name.replace(".png", "")
bin_name_softmax = f + "_0.bin" bin_name_softmax = f + "_0.bin"
@ -100,7 +87,7 @@ if __name__ == '__main__':
file_name_arg = rst_path + bin_name_argmax file_name_arg = rst_path + bin_name_argmax
softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2) softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96) argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
mask = cv2.imread(os.path.join(args.data_url, f, "mask.png"), cv2.IMREAD_GRAYSCALE) mask = cv2.imread(os.path.join(config.data_path, f, "mask.png"), cv2.IMREAD_GRAYSCALE)
mask = cv2.resize(mask, img_size) mask = cv2.resize(mask, img_size)
mask = mask.astype(np.float32) / 255 mask = mask.astype(np.float32) / 255
mask = (mask > 0.5).astype(np.int) mask = (mask > 0.5).astype(np.int)

View File

@ -13,19 +13,18 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""unet 310 infer preprocess dataset""" """unet 310 infer preprocess dataset"""
import argparse
import os import os
import numpy as np import numpy as np
import cv2 import cv2
from src.data_loader import create_dataset from src.data_loader import create_dataset
from src.config import cfg_unet from src.model_utils.config import config
def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None): def preprocess_dataset(data_dir, result_path, cross_valid_ind=1):
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=config.crop,
img_size=cfg['img_size']) img_size=config.img_size)
labels_list = [] labels_list = []
for i, data in enumerate(valid_dataset): for i, data in enumerate(valid_dataset):
@ -87,21 +86,9 @@ class CellNucleiDataset:
return len(self.val_ids) return len(self.val_ids)
def get_args():
parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-p', '--result_path', dest='result_path', type=str, default='./preprocess_Result/',
help='result path')
return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
args = get_args() if config.dataset == "Cell_nuclei":
cell_dataset = CellNucleiDataset(config.data_path, 1, config.result_path, False, 0.8)
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8)
else: else:
preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, preprocess_dataset(data_dir=config.data_path, cross_valid_ind=config.cross_valid_ind,
result_path=args.result_path) result_path=config.result_path)

View File

@ -17,10 +17,9 @@ Preprocess dataset.
Images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`. Images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`.
""" """
import os import os
import argparse
import cv2 import cv2
import numpy as np import numpy as np
from model_zoo.official.cv.unet.src.config import cfg_unet from model_zoo.official.cv.unet.src.model_utils.config import config
def annToMask(ann, height, width): def annToMask(ann, height, width):
"""Convert annotation to RLE and then to binary mask.""" """Convert annotation to RLE and then to binary mask."""
@ -107,32 +106,27 @@ def preprocess_coco_dataset(param_dict):
cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img) cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img)
cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask) cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask)
def preprocess_dataset(cfg, data_dir): def preprocess_dataset(data_dir):
"""Select preprocess function.""" """Select preprocess function."""
if cfg['dataset'].lower() == "cell_nuclei": if config.dataset.lower() == "cell_nuclei":
preprocess_cell_nuclei_dataset({"data_dir": data_dir}) preprocess_cell_nuclei_dataset({"data_dir": data_dir})
elif cfg['dataset'].lower() == "coco": elif config.dataset.lower() == "coco":
if 'split' in cfg and cfg['split'] == 1.0: if config.split == 1.0:
train_data_path = os.path.join(data_dir, "train") train_data_path = os.path.join(data_dir, "train")
val_data_path = os.path.join(data_dir, "val") val_data_path = os.path.join(data_dir, "val")
train_param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"], train_param_dict = {"anno_json": config.anno_json, "coco_classes": config.coco_classes,
"coco_dir": cfg["coco_dir"], "save_dir": train_data_path} "coco_dir": config.coco_dir, "save_dir": train_data_path}
preprocess_coco_dataset(train_param_dict) preprocess_coco_dataset(train_param_dict)
val_param_dict = {"anno_json": cfg["val_anno_json"], "coco_classes": cfg["coco_classes"], val_param_dict = {"anno_json": config.val_anno_json, "coco_classes": config.coco_classes,
"coco_dir": cfg["val_coco_dir"], "save_dir": val_data_path} "coco_dir": config.val_coco_dir, "save_dir": val_data_path}
preprocess_coco_dataset(val_param_dict) preprocess_coco_dataset(val_param_dict)
else: else:
param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"], param_dict = {"anno_json": config.anno_json, "coco_classes": config.coco_classes,
"coco_dir": cfg["coco_dir"], "save_dir": data_dir} "coco_dir": config.coco_dir, "save_dir": data_dir}
preprocess_coco_dataset(param_dict) preprocess_coco_dataset(param_dict)
else: else:
raise ValueError("Not support dataset mode {}".format(cfg['dataset'])) raise ValueError("Not support dataset mode {}".format(config.dataset))
print("========== end preprocess dataset ==========") print("========== end preprocess dataset ==========")
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', preprocess_dataset(config.data_path)
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='save data directory')
args = parser.parse_args()
preprocess_dataset(cfg_unet, args.data_url)

View File

@ -22,13 +22,13 @@ get_real_path() {
fi fi
} }
if [ $# != 2 ] if [ $# != 3 ]
then then
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]"
echo "Please run the script as: " echo "Please run the script as: "
echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]"
echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data" echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data /absolute/path/to/config"
echo "==============================================================================================================" echo "=============================================================================================================="
exit 1 exit 1
fi fi
@ -36,6 +36,7 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export HCCL_CONNECT_TIMEOUT=600 export HCCL_CONNECT_TIMEOUT=600
export RANK_SIZE=8 export RANK_SIZE=8
DATASET=$(get_real_path $2) DATASET=$(get_real_path $2)
CONFIG_PATH=$(get_real_path $3)
export RANK_TABLE_FILE=$(get_real_path $1) export RANK_TABLE_FILE=$(get_real_path $1)
for((i=0;i<RANK_SIZE;i++)) for((i=0;i<RANK_SIZE;i++))
do do
@ -52,7 +53,9 @@ do
python3 ${PROJECT_DIR}/../train.py \ python3 ${PROJECT_DIR}/../train.py \
--run_distribute=True \ --run_distribute=True \
--data_url=$DATASET > log.txt 2>&1 & --data_path=$DATASET \
--config_path=$CONFIG_PATH \
--output_path './output' > log.txt 2>&1 &
cd ../ cd ../
done done

View File

@ -22,23 +22,24 @@ get_real_path() {
fi fi
} }
if [ $# != 2 ] && [ $# != 3 ] if [ $# != 3 ] && [ $# != 4 ]
then then
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [DEVICE_ID](option, default is 0)" echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](option, default is 0)"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ 0" echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/ 0"
echo "==============================================================================================================" echo "=============================================================================================================="
exit 1 exit 1
fi fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0 export DEVICE_ID=0
if [ $# != 2 ] if [ $# != 3 ]
then then
export DEVICE_ID=$3 export DEVICE_ID=$4
fi fi
DATASET=$(get_real_path $1) DATASET=$(get_real_path $1)
CHECKPOINT=$(get_real_path $2) CHECKPOINT=$(get_real_path $2)
CONFIG_PATH=$(get_real_path $3)
echo "========== start run evaluation ===========" echo "========== start run evaluation ==========="
echo "please get log at eval.log" echo "please get log at eval.log"
python ${PROJECT_DIR}/../eval.py --data_url=$DATASET --ckpt_path=$CHECKPOINT > eval.log 2>&1 & python ${PROJECT_DIR}/../eval.py --data_path=$DATASET --checkpoint_file_path=$CHECKPOINT --config_path=$CONFIG_PATH > eval.log 2>&1 &

View File

@ -22,23 +22,24 @@ get_real_path() {
fi fi
} }
if [ $# != 1 ] && [ $# != 2 ] if [ $# != 2 ] && [ $# != 3 ]
then then
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "bash scripts/run_standalone_train.sh [DATASET] [DEVICE_ID](option, default is 0)" echo "bash scripts/run_standalone_train.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](option, default is 0)"
echo "for example: bash run_standalone_train.sh /path/to/data/ 0" echo "for example: bash run_standalone_train.sh /path/to/data/ /path/to/config/ 0"
echo "==============================================================================================================" echo "=============================================================================================================="
exit 1 exit 1
fi fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0 export DEVICE_ID=0
if [ $# != 1 ] if [ $# != 2 ]
then then
export DEVICE_ID=$2 export DEVICE_ID=$3
fi fi
DATASET=$(get_real_path $1) DATASET=$(get_real_path $1)
CONFIG_PATH=$(get_real_path $2)
echo "========== start run training ===========" echo "========== start run training ==========="
echo "please get log at train.log" echo "please get log at train.log"
python ${PROJECT_DIR}/../train.py --data_url=$DATASET > train.log 2>&1 & python ${PROJECT_DIR}/../train.py --data_path=$DATASET --config_path=$CONFIG_PATH --output './output'> train.log 2>&1 &

View File

@ -1,178 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cfg_unet_medical = {
'model': 'unet_medical',
'crop': [388 / 572, 388 / 572],
'img_size': [572, 572],
'lr': 0.0001,
'epochs': 400,
'repeat': 400,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 1,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ['outc.weight', 'outc.bias'],
'eval_activate': 'Softmax',
'eval_resize': False
}
cfg_unet_nested = {
'model': 'unet_nested',
'crop': None,
'img_size': [576, 576],
'lr': 0.0001,
'epochs': 400,
'repeat': 400,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 1,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'use_bn': True,
'use_ds': True,
'use_deconv': True,
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
'eval_activate': 'Softmax',
'eval_resize': False
}
cfg_unet_nested_cell = {
'model': 'unet_nested',
'dataset': 'Cell_nuclei',
'crop': None,
'img_size': [96, 96],
'lr': 3e-4,
'epochs': 200,
'repeat': 10,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 3,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'use_bn': True,
'use_ds': True,
'use_deconv': True,
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
'eval_activate': 'Softmax',
'eval_resize': False
}
cfg_unet_simple = {
'model': 'unet_simple',
'crop': None,
'img_size': [576, 576],
'lr': 0.0001,
'epochs': 400,
'repeat': 400,
'distribute_epochs': 2400,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 1,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ["final.weight"],
'eval_activate': 'Softmax',
'eval_resize': False
}
cfg_unet_simple_coco = {
'model': 'unet_simple',
'dataset': 'COCO',
'split': 1.0,
'img_size': [512, 512],
'lr': 3e-4,
'epochs': 80,
'repeat': 1,
'distribute_epochs': 120,
'cross_valid_ind': 1,
'batchsize': 16,
'num_channels': 3,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ["final.weight"],
'eval_activate': 'Softmax',
'eval_resize': False,
'num_classes': 81,
'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
# change the following settings to real path
'anno_json': '/data/coco2017/annotations/instances_train2017.json',
'val_anno_json': '/data/coco2017/annotations/instances_val2017.json',
'coco_dir': '/data/coco2017/train2017',
'val_coco_dir': '/data/coco2017/val2017'
}
cfg_unet = cfg_unet_simple
if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']:
print("ISBI dataset not support resize to original image size when in evaluation.")
cfg_unet['eval_resize'] = False

View File

@ -122,8 +122,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False) ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False) ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
if do_crop: if do_crop != "None":
resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))] resize_size = [int(img_size[x] * do_crop[x] / 572) for x in range(len(img_size))]
else: else:
resize_size = img_size resize_size = img_size
c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR) c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
@ -146,7 +146,7 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
train_ds = train_ds.map(input_columns="image", operations=c_resize_op) train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
train_ds = train_ds.map(input_columns="mask", operations=c_resize_op) train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
if do_crop: if do_crop != "None":
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop) train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process post_process = data_post_process
train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process) train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
@ -157,7 +157,7 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask) valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
valid_ds = ds.zip((valid_image_ds, valid_mask_ds)) valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
valid_ds = valid_ds.project(columns=["image", "mask"]) valid_ds = valid_ds.project(columns=["image", "mask"])
if do_crop: if do_crop != "None":
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop) valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process post_process = data_post_process
valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process) valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)

View File

@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
_config_path = "./unet_simple_config.yaml"
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="unet_simple_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
else:
raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../unet_simple_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -15,12 +15,12 @@
"""Device adapter for ModelArts""" """Device adapter for ModelArts"""
from utils.config import config from src.model_utils.config import config
if config.enable_modelarts: if config.enable_modelarts:
from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else: else:
from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [ __all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id" "get_device_id", "get_device_num", "get_rank_id", "get_job_id"

View File

@ -18,7 +18,7 @@
import os import os
import functools import functools
from mindspore import context from mindspore import context
from utils.config import config from src.model_utils.config import config
_global_sync_count = 0 _global_sync_count = 0

View File

@ -21,6 +21,7 @@ from mindspore import nn
from mindspore.ops import operations as ops from mindspore.ops import operations as ops
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from src.model_utils.config import config
class UnetEval(nn.Cell): class UnetEval(nn.Cell):
""" """
@ -63,10 +64,9 @@ def apply_eval(eval_param_dict):
class dice_coeff(nn.Metric): class dice_coeff(nn.Metric):
"""Unet Metric, return dice coefficient and IOU.""" """Unet Metric, return dice coefficient and IOU."""
def __init__(self, cfg_unet, print_res=True): def __init__(self, print_res=True):
super(dice_coeff, self).__init__() super(dice_coeff, self).__init__()
self.clear() self.clear()
self.cfg_unet = cfg_unet
self.print_res = print_res self.print_res = print_res
def clear(self): def clear(self):
@ -84,20 +84,20 @@ class dice_coeff(nn.Metric):
if b != 1: if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.') raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c)) y = y.reshape((h, w, c))
if self.cfg_unet["eval_activate"].lower() == "softmax": if config.eval_activate.lower() == "softmax":
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0) y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
if self.cfg_unet["eval_resize"]: if config.eval_resize:
y_pred = [] y_pred = []
for i in range(self.cfg_unet["num_classes"]): for i in range(config.num_classes):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255) y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1) y_pred = np.stack(y_pred, axis=-1)
else: else:
y_pred = y_softmax y_pred = y_softmax
elif self.cfg_unet["eval_activate"].lower() == "argmax": elif config.eval_activate.lower() == "argmax":
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0) y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
y_pred = [] y_pred = []
for i in range(self.cfg_unet["num_classes"]): for i in range(config.num_classes):
if self.cfg_unet["eval_resize"]: if config.eval_resize:
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST)) y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
else: else:
y_pred.append(np.float32(y_argmax == i)) y_pred.append(np.float32(y_argmax == i))

View File

@ -14,14 +14,12 @@
# ============================================================================ # ============================================================================
import os import os
import argparse
import logging import logging
import ast
import mindspore import mindspore
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Model, context from mindspore import Model, context
from mindspore.communication.management import init, get_group_size, get_rank from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -31,133 +29,110 @@ from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset, create_multi_class_dataset from src.data_loader import create_dataset, create_multi_class_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
from src.config import cfg_unet
from src.eval_callback import EvalCallBack from src.eval_callback import EvalCallBack
device_id = int(os.getenv('DEVICE_ID')) from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_rank_id, get_device_num
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
mindspore.set_seed(1) mindspore.set_seed(1)
def train_net(args_opt, @moxing_wrapper()
cross_valid_ind=1, def train_net(cross_valid_ind=1,
epochs=400, epochs=400,
batch_size=16, batch_size=16,
lr=0.0001, lr=0.0001):
cfg=None):
rank = 0 rank = 0
group_size = 1 group_size = 1
data_dir = args_opt.data_url data_dir = config.data_path
run_distribute = args_opt.run_distribute run_distribute = config.run_distribute
if run_distribute: if run_distribute:
init() init()
group_size = get_group_size() group_size = get_device_num()
rank = get_rank() rank = get_rank_id()
parallel_mode = ParallelMode.DATA_PARALLEL parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size, device_num=group_size,
gradients_mean=False) gradients_mean=False)
need_slice = False need_slice = False
if cfg['model'] == 'unet_medical': if config.model_name == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
elif cfg['model'] == 'unet_nested': elif config.model_name == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
use_bn=cfg['use_bn'], use_ds=cfg['use_ds']) use_bn=config.use_bn, use_ds=config.use_ds)
need_slice = cfg['use_ds'] need_slice = config.use_ds
elif cfg['model'] == 'unet_simple': elif config.model_name == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
else: else:
raise ValueError("Unsupported model: {}".format(cfg['model'])) raise ValueError("Unsupported model: {}".format(config.model_name))
if cfg['resume']: if config.resume:
param_dict = load_checkpoint(cfg['resume_ckpt']) param_dict = load_checkpoint(config.resume_ckpt)
if cfg['transfer_training']: if config.transfer_training:
filter_checkpoint_parameter_by_list(param_dict, cfg['filter_weight']) filter_checkpoint_parameter_by_list(param_dict, config.filter_weight)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
if 'use_ds' in cfg and cfg['use_ds']: if hasattr(config, "use_ds") and config.use_ds:
criterion = MultiCrossEntropyWithLogits() criterion = MultiCrossEntropyWithLogits()
else: else:
criterion = CrossEntropyWithLogits() criterion = CrossEntropyWithLogits()
if 'dataset' in cfg and cfg['dataset'] != "ISBI": if hasattr(config, "dataset") and config.dataset != "ISBI":
repeat = cfg['repeat'] if 'repeat' in cfg else 1
split = cfg['split'] if 'split' in cfg else 0.8
dataset_sink_mode = True dataset_sink_mode = True
per_print_times = 0 per_print_times = 0
train_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], repeat, batch_size, repeat = config.repeat if hasattr(config, "repeat") else 1
num_classes=cfg['num_classes'], is_train=True, augment=True, split = config.split if hasattr(config, "split") else 0.8
train_dataset = create_multi_class_dataset(data_dir, config.image_size, repeat, batch_size,
num_classes=config.num_classes, is_train=True, augment=True,
split=split, rank=rank, group_size=group_size, shuffle=True) split=split, rank=rank, group_size=group_size, shuffle=True)
valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1, valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
num_classes=cfg['num_classes'], is_train=False, num_classes=config.num_classes, is_train=False,
eval_resize=cfg["eval_resize"], split=split, eval_resize=config.eval_resize, split=split,
python_multiprocessing=False, shuffle=False) python_multiprocessing=False, shuffle=False)
else: else:
repeat = cfg['repeat'] repeat = config.repeat
dataset_sink_mode = False dataset_sink_mode = False
per_print_times = 1 per_print_times = 1
train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
run_distribute, cfg["crop"], cfg['img_size']) run_distribute, config.crop, config.image_size)
train_data_size = train_dataset.get_dataset_size() train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size) print("dataset length is:", train_data_size)
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=cfg['keep_checkpoint_max']) keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']), ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name),
directory='./ckpt_{}/'.format(device_id), directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id),
config=ckpt_config) config=ckpt_config)
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'], optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay,
loss_scale=cfg['loss_scale']) loss_scale=config.loss_scale)
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False) loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.FixedLossScaleManager, False)
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3") model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
print("============== Starting Training ==============") print("============== Starting Training ==============")
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb] callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
if args_opt.run_eval: if config.run_eval:
eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(), eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
metrics={"dice_coeff": dice_coeff(cfg_unet, False)}) metrics={"dice_coeff": dice_coeff(False)})
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics} eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt", ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
metrics_name=args_opt.eval_metrics) metrics_name=config.eval_metrics)
callbacks.append(eval_cb) callbacks.append(eval_cb)
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode) model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
print("============== End Training ==============") print("============== End Training ==============")
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
default=False, help='Run distribute, default: false.')
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=0,
help="Evaluation start epoch when run_eval is True, default is 0.")
parser.add_argument("--eval_interval", type=int, default=1,
help="Evaluation interval when run_eval is True, default is 1.")
parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], "
"default is dice_coeff.")
return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
print("Training setting:", args)
epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs'] epoch_size = config.epochs if not config.run_distribute else config.distribute_epochs
train_net(args_opt=args, train_net(cross_valid_ind=config.cross_valid_ind,
cross_valid_ind=cfg_unet['cross_valid_ind'],
epochs=epoch_size, epochs=epoch_size,
batch_size=cfg_unet['batchsize'], batch_size=config.batch_size,
lr=cfg_unet['lr'], lr=config.lr)
cfg=cfg_unet)

View File

@ -0,0 +1,67 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_medical'
run_eval: False
run_distribute: False
crop: [388, 388]
image_size : [572, 572]
lr: 0.0001
epochs: 400
repeat: 400
distribute_epochs: 1600
batch_size: 16
cross_valid_ind: 1
num_classes: 2
num_channels: 1
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['outc.weight', 'outc.bias']
#Eval options
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_medical_adam-4-75.ckpt'
rst_path: './result_Files/'
# Export options
width: 572
height: 572
file_name: unet
file_format: AIR
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -0,0 +1,71 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_nested'
run_eval: False
run_distribute: False
dataset: 'Cell_nuclei'
crop: None
image_size : [96, 96]
lr: 0.0003
epochs: 200
repeat: 10
distribute_epochs: 1600
batch_size: 16
cross_valid_ind: 1
num_classes: 2
num_channels: 3
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
use_ds: False
use_bn: False
use_deconv: True
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
#Eval options
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt'
rst_path: './result_Files/'
# Export options
width: 572
height: 572
file_name: unet
file_format: AIR
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -0,0 +1,71 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_nested'
run_eval: False
run_distribute: False
crop: None
image_size : [576, 576]
lr: 0.0001
epochs: 400
repeat: 400
distribute_epochs: 1600
batch_size: 16
cross_valid_ind: 1
num_classes: 2
num_channels: 1
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
use_ds: True
use_bn: True
use_deconv: True
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
#Eval options
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt'
rst_path: './result_Files/'
# Export options
width: 572
height: 572
file_name: unet
file_format: AIR
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -0,0 +1,91 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_simple'
run_eval: False
run_distribute: False
dataset: 'COCO'
split : 1.0
image_size : [512, 512]
lr: 0.0001
epochs: 80
repeat: 1
distribute_epochs: 120
batch_size: 16
cross_valid_ind: 1
num_classes: 81
num_channels: 3
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['final1.weight']
coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
anno_json: '/data/coco2017/annotations/instances_train2017.json'
val_anno_json: '/data/coco2017/annotations/instances_val2017.json'
coco_dir: '/data/coco2017/train2017'
val_coco_dir: '/data/coco2017/val2017'
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
rst_path: './result_Files/'
# Export options
width: 572
height: 572
file_name: unet
file_format: AIR
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -0,0 +1,68 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_simple'
run_eval: False
run_distribute: False
crop: None
image_size : [576, 576]
lr: 0.0001
epochs: 400
repeat: 400
distribute_epochs: 2400
batch_size: 16
cross_valid_ind: 1
num_classes: 2
num_channels: 1
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['final1.weight']
#Eval options
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
rst_path: './result_Files/'
# Export options
width: 572
height: 572
file_name: unet
file_format: AIR
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -16,7 +16,7 @@
if [ $# -ne 2 ] if [ $# -ne 2 ]
then then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]" echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [DATA_PATH]"
exit 1 exit 1
fi fi

View File

@ -18,14 +18,14 @@ if [ $# != 2 ]
then then
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]" echo "bash scripts/run_standalone_eval.sh [DATA_PATH] [CHECKPOINT]"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/" echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
echo "==============================================================================================================" echo "=============================================================================================================="
fi fi
if [ $# != 2 ] if [ $# != 2 ]
then then
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]" echo "Usage: sh run_eval_ascend.sh [DATA_PATH] [CHECKPOINT_PATH]"
exit 1 exit 1
fi fi

View File

@ -16,7 +16,7 @@
if [ $# -ne 1 ] if [ $# -ne 1 ]
then then
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]" echo "Usage: sh run_distribute_train_ascend.sh [DATA_PATH]"
exit 1 exit 1
fi fi

View File

@ -100,26 +100,26 @@ If you want to run in modelarts, please check the official documentation of [mod
```bash ```bash
├── model_zoo ├── model_zoo
├── README.md // descriptions about all the models ├── README.md // descriptions about all the models
├── textcnn ├── textcnn
├── README.md // descriptions about textcnn ├── README.md // descriptions about textcnn
├──scripts ├──scripts
│ ├── run_train.sh // shell script for distributed on Ascend │ ├── run_train.sh // shell script for distributed on Ascend
│ ├── run_eval.sh // shell script for evaluation on Ascend │ ├── run_eval.sh // shell script for evaluation on Ascend
├── src ├── src
│ ├── dataset.py // Processing dataset │ ├── dataset.py // Processing dataset
│ ├── textcnn.py // textcnn architecture │ ├── textcnn.py // textcnn architecture
├── utils ├── utils
│ ├──device_adapter.py // device adapter │ ├──device_adapter.py // device adapter
│ ├──local_adapter.py // local adapter │ ├──local_adapter.py // local adapter
│ ├──moxing_adapter.py // moxing adapter │ ├──moxing_adapter.py // moxing adapter
│ ├── config.py // parameter analysis │ ├──config.py // parameter analysis
├── mr_config.yaml // parameter configuration ├── mr_config.yaml // parameter configuration
├── sst2_config.yaml // parameter configuration ├── sst2_config.yaml // parameter configuration
├── subj_config.yaml // parameter configuration ├── subj_config.yaml // parameter configuration
├── train.py // training script ├── train.py // training script
├── eval.py // evaluation script ├── eval.py // evaluation script
├── export.py // export checkpoint to other format file ├── export.py // export checkpoint to other format file
``` ```
## [Script Parameters](#contents) ## [Script Parameters](#contents)

View File

@ -22,9 +22,9 @@ from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from utils.moxing_adapter import moxing_wrapper from src.model_utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_device_id from src.model_utils.device_adapter import get_device_id
from utils.config import config from src.model_utils.config import config
from src.textcnn import TextCNN from src.textcnn import TextCNN
from src.dataset import MovieReview, SST2, Subjectivity from src.dataset import MovieReview, SST2, Subjectivity

View File

@ -20,7 +20,7 @@ import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from utils.config import config from src.model_utils.config import config
from src.textcnn import TextCNN from src.textcnn import TextCNN
from src.dataset import MovieReview, SST2, Subjectivity from src.dataset import MovieReview, SST2, Subjectivity

View File

@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
word_len: 51 word_len: 51
vec_length: 40 vec_length: 40
base_lr: 1e-3 base_lr: 1e-3
label_dir: ""
result_dir: ""
result_path: './preprocess_Result/'
# Export options # Export options
device_id: 0 device_id: 0

View File

@ -16,35 +16,21 @@
##############postprocess################# ##############postprocess#################
""" """
import os import os
import argparse
import numpy as np import numpy as np
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from src.config import cfg_mr, cfg_subj, cfg_sst2 from src.model_utils.config import config
parser = argparse.ArgumentParser(description='postprocess')
parser.add_argument('--label_dir', type=str, default="", help='label data dir')
parser.add_argument('--result_dir', type=str, default="", help="infer result dir")
parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
if args.dataset == 'MR':
cfg = cfg_mr
elif args.dataset == 'SUBJ':
cfg = cfg_subj
elif args.dataset == 'SST2':
cfg = cfg_sst2
file_prefix = 'textcnn_bs' + str(cfg.batch_size) + '_' file_prefix = 'textcnn_bs' + str(config.batch_size) + '_'
metric = Accuracy() metric = Accuracy()
metric.clear() metric.clear()
label_list = np.load(args.label_dir, allow_pickle=True) label_list = np.load(config.label_dir, allow_pickle=True)
for idx, label in enumerate(label_list): for idx, label in enumerate(label_list):
pred = np.fromfile(os.path.join(args.result_dir, file_prefix + str(idx) + '_0.bin'), np.float32) pred = np.fromfile(os.path.join(config.result_dir, file_prefix + str(idx) + '_0.bin'), np.float32)
pred = pred.reshape(cfg.batch_size, int(pred.shape[0]/cfg.batch_size)) pred = pred.reshape(config.batch_size, int(pred.shape[0]/config.batch_size))
metric.update(pred, label) metric.update(pred, label)
accuracy = metric.eval() accuracy = metric.eval()
print("accuracy: ", accuracy) print("accuracy: ", accuracy)

View File

@ -15,37 +15,28 @@
""" """
##############preprocess textcnn example on movie review################# ##############preprocess textcnn example on movie review#################
""" """
import argparse
import os import os
import numpy as np import numpy as np
from src.config import cfg_mr, cfg_subj, cfg_sst2 from src.model_utils.config import config
from src.dataset import MovieReview, SST2, Subjectivity from src.dataset import MovieReview, SST2, Subjectivity
parser = argparse.ArgumentParser(description='TextCNN')
parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
args_opt = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
if args_opt.dataset == 'MR': if config.dataset == 'MR':
cfg = cfg_mr instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) elif config.dataset == 'SUBJ':
elif args_opt.dataset == 'SUBJ': instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
cfg = cfg_subj elif config.dataset == 'SST2':
instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
elif args_opt.dataset == 'SST2':
cfg = cfg_sst2
instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
dataset = instance.create_test_dataset(batch_size=cfg.batch_size) dataset = instance.create_test_dataset(batch_size=config.batch_size)
img_path = os.path.join(args_opt.result_path, "00_data") img_path = os.path.join(config.result_path, "00_data")
os.makedirs(img_path) os.makedirs(img_path)
label_list = [] label_list = []
for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
file_name = "textcnn_bs" + str(cfg.batch_size) + "_" + str(i) + ".bin" file_name = "textcnn_bs" + str(config.batch_size) + "_" + str(i) + ".bin"
file_path = img_path + "/" + file_name file_path = img_path + "/" + file_name
data['data'].tofile(file_path) data['data'].tofile(file_path)
label_list.append(data['label']) label_list.append(data['label'])
np.save(args_opt.result_path + "label_ids.npy", label_list) np.save(config.result_path + "label_ids.npy", label_list)
print("="*20, "export bin files finished", "="*20) print("="*20, "export bin files finished", "="*20)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from src.model_utils.config import config
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
word_len: 51 word_len: 51
vec_length: 40 vec_length: 40
base_lr: 5e-3 base_lr: 5e-3
label_dir: ""
result_dir: ""
result_path: './preprocess_Result/'
# Export options # Export options
device_id: 0 device_id: 0

View File

@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
word_len: 51 word_len: 51
vec_length: 40 vec_length: 40
base_lr: 8e-4 base_lr: 8e-4
label_dir: ""
result_dir: ""
result_path: './preprocess_Result/'
# Export options # Export options
device_id: 0 device_id: 0

View File

@ -26,9 +26,9 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from utils.moxing_adapter import moxing_wrapper from src.model_utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_device_id, get_rank_id from src.model_utils.device_adapter import get_device_id, get_rank_id
from utils.config import config from src.model_utils.config import config
from src.textcnn import TextCNN from src.textcnn import TextCNN
from src.textcnn import SoftmaxCrossEntropyExpand from src.textcnn import SoftmaxCrossEntropyExpand
from src.dataset import MovieReview, SST2, Subjectivity from src.dataset import MovieReview, SST2, Subjectivity