!16400 modify unet for cloud and modify model_utils of textcnn
From: @Somnus2020 Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
844afd374b
|
@ -160,6 +160,35 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
|
||||||
|
|
||||||
Then you can run everything just like on ascend.
|
Then you can run everything just like on ascend.
|
||||||
|
|
||||||
|
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# run distributed training on modelarts example
|
||||||
|
# (1) First, Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on yaml file.
|
||||||
|
# Set other parameters on yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Set the config directory to "config_path=/The path of config in S3/"
|
||||||
|
# (3) Set the code directory to "/path/unet" on the website UI interface.
|
||||||
|
# (4) Set the startup file to "train.py" on the website UI interface.
|
||||||
|
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (6) Create your job.
|
||||||
|
|
||||||
|
# run evaluation on modelarts example
|
||||||
|
# (1) Copy or upload your trained model to S3 bucket.
|
||||||
|
# (2) Perform a or b.
|
||||||
|
# a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
|
||||||
|
# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
|
||||||
|
# b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# (3) Set the config directory to "config_path=/The path of config in S3/"
|
||||||
|
# (4) Set the code directory to "/path/unet" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
## [Script Description](#contents)
|
## [Script Description](#contents)
|
||||||
|
|
||||||
### [Script and Sample Code](#contents)
|
### [Script and Sample Code](#contents)
|
||||||
|
@ -190,6 +219,16 @@ Then you can run everything just like on ascend.
|
||||||
├──__init__.py // init file
|
├──__init__.py // init file
|
||||||
├──unet_model.py // unet model
|
├──unet_model.py // unet model
|
||||||
├──unet_parts.py // unet part
|
├──unet_parts.py // unet part
|
||||||
|
├── model_utils
|
||||||
|
│ ├── config.py // parameter configuration
|
||||||
|
│ ├── device_adapter.py // device adapter
|
||||||
|
│ ├── local_adapter.py // local adapter
|
||||||
|
│ ├── moxing_adapter.py // moxing adapter
|
||||||
|
├── unet_medical_config.yaml // parameter configuration
|
||||||
|
├── unet_nested_cell_config.yaml // parameter configuration
|
||||||
|
├── unet_nested_config.yaml // parameter configuration
|
||||||
|
├── unet_simple_config.yaml // parameter configuration
|
||||||
|
├── unet_simple_coco_config.yaml // parameter configuration
|
||||||
├── train.py // training script
|
├── train.py // training script
|
||||||
├── eval.py // evaluation script
|
├── eval.py // evaluation script
|
||||||
├── export.py // export script
|
├── export.py // export script
|
||||||
|
|
|
@ -164,6 +164,38 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
|
||||||
|
|
||||||
然后在容器里的操作就和Ascend平台上是一样的。
|
然后在容器里的操作就和Ascend平台上是一样的。
|
||||||
|
|
||||||
|
如果要在modelarts上进行模型的训练,可以参考modelarts的官方指导文档(https://support.huaweicloud.com/modelarts/)
|
||||||
|
开始进行模型的训练和推理,具体操作如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在modelarts上使用分布式训练的示例:
|
||||||
|
# (1) 选址a或者b其中一种方式。
|
||||||
|
# a. 设置 "enable_modelarts=True" 。
|
||||||
|
# 在yaml文件上设置网络所需的参数。
|
||||||
|
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||||
|
# 在modelarts的界面上设置网络所需的参数。
|
||||||
|
# (2)设置网络配置文件的路径 "config_path=/The path of config in S3/"
|
||||||
|
# (3) 在modelarts的界面上设置代码的路径 "/path/unet"。
|
||||||
|
# (4) 在modelarts的界面上设置模型的启动文件 "train.py" 。
|
||||||
|
# (5) 在modelarts的界面上设置模型的数据路径 "Dataset path" ,
|
||||||
|
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||||
|
# (6) 开始模型的训练。
|
||||||
|
|
||||||
|
# 在modelarts上使用模型推理的示例
|
||||||
|
# (1) 把训练好的模型地方到桶的对应位置。
|
||||||
|
# (2) 选址a或者b其中一种方式。
|
||||||
|
# a. 设置 "checkpoint_file_path='/cache/checkpoint_path/model.ckpt" 在 yaml 文件.
|
||||||
|
# 设置 "checkpoint_url=/The path of checkpoint in S3/" 在 yaml 文件.
|
||||||
|
# b. 增加 "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" 参数在modearts的界面上。
|
||||||
|
# 增加 "checkpoint_url=/The path of checkpoint in S3/" 参数在modearts的界面上。
|
||||||
|
# (3) 设置网络配置文件的路径 "config_path=/The path of config in S3/"
|
||||||
|
# (4) 在modelarts的界面上设置代码的路径 "/path/unet"。
|
||||||
|
# (5) 在modelarts的界面上设置模型的启动文件 "eval.py" 。
|
||||||
|
# (6) 在modelarts的界面上设置模型的数据路径 "Dataset path" ,
|
||||||
|
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||||
|
# (7) 开始模型的推理。
|
||||||
|
```
|
||||||
|
|
||||||
## 脚本说明
|
## 脚本说明
|
||||||
|
|
||||||
### 脚本及样例代码
|
### 脚本及样例代码
|
||||||
|
@ -194,6 +226,16 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
|
||||||
├──__init__.py
|
├──__init__.py
|
||||||
├──unet_model.py // Unet++ 网络结构
|
├──unet_model.py // Unet++ 网络结构
|
||||||
├──unet_parts.py // Unet++ 子网
|
├──unet_parts.py // Unet++ 子网
|
||||||
|
├── model_utils
|
||||||
|
│ ├── config.py // 参数配置
|
||||||
|
│ ├── device_adapter.py // 设备配置
|
||||||
|
│ ├── local_adapter.py // 本地设备配置
|
||||||
|
│ ├── moxing_adapter.py // modelarts设备配置
|
||||||
|
├── unet_medical_config.yaml // 配置文件
|
||||||
|
├── unet_nested_cell_config.yaml // 配置文件
|
||||||
|
├── unet_nested_config.yaml // 配置文件
|
||||||
|
├── unet_simple_config.yaml // 配置文件
|
||||||
|
├── unet_simple_coco_config.yaml // 配置文件
|
||||||
├── train.py // 训练脚本
|
├── train.py // 训练脚本
|
||||||
├── eval.py // 推理脚本
|
├── eval.py // 推理脚本
|
||||||
├── export.py // 导出脚本
|
├── export.py // 导出脚本
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import logging
|
import logging
|
||||||
from mindspore import context, Model
|
from mindspore import context, Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
@ -22,38 +21,39 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.data_loader import create_dataset, create_multi_class_dataset
|
from src.data_loader import create_dataset, create_multi_class_dataset
|
||||||
from src.unet_medical import UNetMedical
|
from src.unet_medical import UNetMedical
|
||||||
from src.unet_nested import NestedUNet, UNet
|
from src.unet_nested import NestedUNet, UNet
|
||||||
from src.config import cfg_unet
|
|
||||||
from src.utils import UnetEval, TempLoss, dice_coeff
|
from src.utils import UnetEval, TempLoss, dice_coeff
|
||||||
|
from src.model_utils.config import config
|
||||||
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
|
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv("DEVICE_ID"))
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||||
|
|
||||||
|
@moxing_wrapper()
|
||||||
def test_net(data_dir,
|
def test_net(data_dir,
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
cross_valid_ind=1,
|
cross_valid_ind=1):
|
||||||
cfg=None):
|
if config.model_name == 'unet_medical':
|
||||||
if cfg['model'] == 'unet_medical':
|
net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
|
||||||
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
elif config.model_name == 'unet_nested':
|
||||||
elif cfg['model'] == 'unet_nested':
|
net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
|
||||||
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
|
use_bn=config.use_bn, use_ds=False)
|
||||||
use_bn=cfg['use_bn'], use_ds=False)
|
elif config.model_name == 'unet_simple':
|
||||||
elif cfg['model'] == 'unet_simple':
|
net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
|
||||||
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported model: {}".format(cfg['model']))
|
raise ValueError("Unsupported model: {}".format(config.model_name))
|
||||||
param_dict = load_checkpoint(ckpt_path)
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
net = UnetEval(net)
|
net = UnetEval(net)
|
||||||
if 'dataset' in cfg and cfg['dataset'] != "ISBI":
|
if hasattr(config, "dataset") and config.dataset != "ISBI":
|
||||||
split = cfg['split'] if 'split' in cfg else 0.8
|
split = config.split if hasattr(config, "dataset") else 0.8
|
||||||
valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
|
valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
|
||||||
num_classes=cfg['num_classes'], is_train=False,
|
num_classes=config.num_classes, is_train=False,
|
||||||
eval_resize=cfg["eval_resize"], split=split,
|
eval_resize=config.eval_resize, split=split,
|
||||||
python_multiprocessing=False, shuffle=False)
|
python_multiprocessing=False, shuffle=False)
|
||||||
else:
|
else:
|
||||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
|
||||||
do_crop=cfg['crop'], img_size=cfg['img_size'])
|
do_crop=config.crop, img_size=config.image_size)
|
||||||
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)})
|
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
|
||||||
|
|
||||||
print("============== Starting Evaluating ============")
|
print("============== Starting Evaluating ============")
|
||||||
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
|
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
|
||||||
|
@ -61,22 +61,8 @@ def test_net(data_dir,
|
||||||
print("============== Cross valid IOU is:", eval_score[1])
|
print("============== Cross valid IOU is:", eval_score[1])
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
|
|
||||||
help='data directory')
|
|
||||||
parser.add_argument('-p', '--ckpt_path', dest='ckpt_path', type=str, default='ckpt_unet_medical_adam-1_600.ckpt',
|
|
||||||
help='checkpoint path')
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||||
args = get_args()
|
test_net(data_dir=config.data_path,
|
||||||
print("Testing setting:", args)
|
ckpt_path=config.checkpoint_file_path,
|
||||||
test_net(data_dir=args.data_url,
|
cross_valid_ind=config.cross_valid_ind)
|
||||||
ckpt_path=args.ckpt_path,
|
|
||||||
cross_valid_ind=cfg_unet['cross_valid_ind'],
|
|
||||||
cfg=cfg_unet)
|
|
||||||
|
|
|
@ -13,46 +13,36 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context
|
from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context
|
||||||
|
|
||||||
from src.unet_medical.unet_model import UNetMedical
|
from src.unet_medical.unet_model import UNetMedical
|
||||||
from src.unet_nested import NestedUNet, UNet
|
from src.unet_nested import NestedUNet, UNet
|
||||||
from src.config import cfg_unet as cfg
|
|
||||||
from src.utils import UnetEval
|
from src.utils import UnetEval
|
||||||
|
from src.model_utils.config import config
|
||||||
|
from src.model_utils.device_adapter import get_device_id
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='unet export')
|
|
||||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
|
||||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
|
||||||
parser.add_argument('--width', type=int, default=572, help='input width')
|
|
||||||
parser.add_argument('--height', type=int, default=572, help='input height')
|
|
||||||
parser.add_argument("--file_name", type=str, default="unet", help="output file name.")
|
|
||||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
|
||||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
|
||||||
help="device target")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||||
if args.device_target == "Ascend":
|
if config.device_target == "Ascend":
|
||||||
context.set_context(device_id=args.device_id)
|
context.set_context(device_id=get_device_id())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if cfg['model'] == 'unet_medical':
|
if config.model == 'unet_medical':
|
||||||
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
|
||||||
elif cfg['model'] == 'unet_nested':
|
elif config.model == 'unet_nested':
|
||||||
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
|
net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
|
||||||
use_bn=cfg['use_bn'], use_ds=False)
|
use_bn=config.use_bn, use_ds=False)
|
||||||
elif cfg['model'] == 'unet_simple':
|
elif config.model == 'unet_simple':
|
||||||
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
|
net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported model: {}".format(cfg['model']))
|
raise ValueError("Unsupported model: {}".format(config.model))
|
||||||
# return a parameter dict for model
|
# return a parameter dict for model
|
||||||
param_dict = load_checkpoint(args.ckpt_file)
|
param_dict = load_checkpoint(config.checkpoint_file_path)
|
||||||
# load the parameter into net
|
# load the parameter into net
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
net = UnetEval(net)
|
net = UnetEval(net)
|
||||||
input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32))
|
input_data = Tensor(np.ones([config.batch_size, config.num_channels, config.height, \
|
||||||
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
|
config.width]).astype(np.float32))
|
||||||
|
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||||
|
|
|
@ -14,11 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""unet 310 infer."""
|
"""unet 310 infer."""
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.config import cfg_unet
|
from src.model_utils.config import config
|
||||||
|
|
||||||
class dice_coeff():
|
class dice_coeff():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -38,20 +37,20 @@ class dice_coeff():
|
||||||
if b != 1:
|
if b != 1:
|
||||||
raise ValueError('Batch size should be 1 when in evaluation.')
|
raise ValueError('Batch size should be 1 when in evaluation.')
|
||||||
y = y.reshape((h, w, c))
|
y = y.reshape((h, w, c))
|
||||||
if cfg_unet["eval_activate"].lower() == "softmax":
|
if config.eval_activate.lower() == "softmax":
|
||||||
y_softmax = np.squeeze(inputs[0][0], axis=0)
|
y_softmax = np.squeeze(inputs[0][0], axis=0)
|
||||||
if cfg_unet["eval_resize"]:
|
if config.eval_resize:
|
||||||
y_pred = []
|
y_pred = []
|
||||||
for m in range(cfg_unet["num_classes"]):
|
for m in range(config.num_classes):
|
||||||
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
|
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
|
||||||
y_pred = np.stack(y_pred, axis=-1)
|
y_pred = np.stack(y_pred, axis=-1)
|
||||||
else:
|
else:
|
||||||
y_pred = y_softmax
|
y_pred = y_softmax
|
||||||
elif cfg_unet["eval_activate"].lower() == "argmax":
|
elif config.eval_activate.lower() == "argmax":
|
||||||
y_argmax = np.squeeze(inputs[0][1], axis=0)
|
y_argmax = np.squeeze(inputs[0][1], axis=0)
|
||||||
y_pred = []
|
y_pred = []
|
||||||
for n in range(cfg_unet["num_classes"]):
|
for n in range(config.num_classes):
|
||||||
if cfg_unet["eval_resize"]:
|
if config.eval_resize:
|
||||||
y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
|
y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
|
||||||
else:
|
else:
|
||||||
y_pred.append(np.float32(y_argmax == n))
|
y_pred.append(np.float32(y_argmax == n))
|
||||||
|
@ -73,25 +72,13 @@ class dice_coeff():
|
||||||
raise RuntimeError('Total samples num must not be 0.')
|
raise RuntimeError('Total samples num must not be 0.')
|
||||||
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
|
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
|
|
||||||
help='data directory')
|
|
||||||
parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
|
|
||||||
help='infer result path')
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = get_args()
|
rst_path = config.rst_path
|
||||||
|
|
||||||
rst_path = args.rst_path
|
|
||||||
metrics = dice_coeff()
|
metrics = dice_coeff()
|
||||||
|
|
||||||
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
|
if config.dataset == "Cell_nuclei":
|
||||||
img_size = tuple(cfg_unet['img_size'])
|
img_size = tuple(config.img_size)
|
||||||
for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
|
for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
|
||||||
f = bin_name.replace(".png", "")
|
f = bin_name.replace(".png", "")
|
||||||
bin_name_softmax = f + "_0.bin"
|
bin_name_softmax = f + "_0.bin"
|
||||||
|
@ -100,7 +87,7 @@ if __name__ == '__main__':
|
||||||
file_name_arg = rst_path + bin_name_argmax
|
file_name_arg = rst_path + bin_name_argmax
|
||||||
softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
|
softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
|
||||||
argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
|
argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
|
||||||
mask = cv2.imread(os.path.join(args.data_url, f, "mask.png"), cv2.IMREAD_GRAYSCALE)
|
mask = cv2.imread(os.path.join(config.data_path, f, "mask.png"), cv2.IMREAD_GRAYSCALE)
|
||||||
mask = cv2.resize(mask, img_size)
|
mask = cv2.resize(mask, img_size)
|
||||||
mask = mask.astype(np.float32) / 255
|
mask = mask.astype(np.float32) / 255
|
||||||
mask = (mask > 0.5).astype(np.int)
|
mask = (mask > 0.5).astype(np.int)
|
||||||
|
|
|
@ -13,19 +13,18 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""unet 310 infer preprocess dataset"""
|
"""unet 310 infer preprocess dataset"""
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from src.data_loader import create_dataset
|
from src.data_loader import create_dataset
|
||||||
from src.config import cfg_unet
|
from src.model_utils.config import config
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):
|
def preprocess_dataset(data_dir, result_path, cross_valid_ind=1):
|
||||||
|
|
||||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=config.crop,
|
||||||
img_size=cfg['img_size'])
|
img_size=config.img_size)
|
||||||
|
|
||||||
labels_list = []
|
labels_list = []
|
||||||
for i, data in enumerate(valid_dataset):
|
for i, data in enumerate(valid_dataset):
|
||||||
|
@ -87,21 +86,9 @@ class CellNucleiDataset:
|
||||||
return len(self.val_ids)
|
return len(self.val_ids)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ',
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
|
|
||||||
help='data directory')
|
|
||||||
parser.add_argument('-p', '--result_path', dest='result_path', type=str, default='./preprocess_Result/',
|
|
||||||
help='result path')
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = get_args()
|
if config.dataset == "Cell_nuclei":
|
||||||
|
cell_dataset = CellNucleiDataset(config.data_path, 1, config.result_path, False, 0.8)
|
||||||
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
|
|
||||||
cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8)
|
|
||||||
else:
|
else:
|
||||||
preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet,
|
preprocess_dataset(data_dir=config.data_path, cross_valid_ind=config.cross_valid_ind,
|
||||||
result_path=args.result_path)
|
result_path=config.result_path)
|
||||||
|
|
|
@ -17,10 +17,9 @@ Preprocess dataset.
|
||||||
Images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`.
|
Images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from model_zoo.official.cv.unet.src.config import cfg_unet
|
from model_zoo.official.cv.unet.src.model_utils.config import config
|
||||||
|
|
||||||
def annToMask(ann, height, width):
|
def annToMask(ann, height, width):
|
||||||
"""Convert annotation to RLE and then to binary mask."""
|
"""Convert annotation to RLE and then to binary mask."""
|
||||||
|
@ -107,32 +106,27 @@ def preprocess_coco_dataset(param_dict):
|
||||||
cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img)
|
cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img)
|
||||||
cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask)
|
cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask)
|
||||||
|
|
||||||
def preprocess_dataset(cfg, data_dir):
|
def preprocess_dataset(data_dir):
|
||||||
"""Select preprocess function."""
|
"""Select preprocess function."""
|
||||||
if cfg['dataset'].lower() == "cell_nuclei":
|
if config.dataset.lower() == "cell_nuclei":
|
||||||
preprocess_cell_nuclei_dataset({"data_dir": data_dir})
|
preprocess_cell_nuclei_dataset({"data_dir": data_dir})
|
||||||
elif cfg['dataset'].lower() == "coco":
|
elif config.dataset.lower() == "coco":
|
||||||
if 'split' in cfg and cfg['split'] == 1.0:
|
if config.split == 1.0:
|
||||||
train_data_path = os.path.join(data_dir, "train")
|
train_data_path = os.path.join(data_dir, "train")
|
||||||
val_data_path = os.path.join(data_dir, "val")
|
val_data_path = os.path.join(data_dir, "val")
|
||||||
train_param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
|
train_param_dict = {"anno_json": config.anno_json, "coco_classes": config.coco_classes,
|
||||||
"coco_dir": cfg["coco_dir"], "save_dir": train_data_path}
|
"coco_dir": config.coco_dir, "save_dir": train_data_path}
|
||||||
preprocess_coco_dataset(train_param_dict)
|
preprocess_coco_dataset(train_param_dict)
|
||||||
val_param_dict = {"anno_json": cfg["val_anno_json"], "coco_classes": cfg["coco_classes"],
|
val_param_dict = {"anno_json": config.val_anno_json, "coco_classes": config.coco_classes,
|
||||||
"coco_dir": cfg["val_coco_dir"], "save_dir": val_data_path}
|
"coco_dir": config.val_coco_dir, "save_dir": val_data_path}
|
||||||
preprocess_coco_dataset(val_param_dict)
|
preprocess_coco_dataset(val_param_dict)
|
||||||
else:
|
else:
|
||||||
param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
|
param_dict = {"anno_json": config.anno_json, "coco_classes": config.coco_classes,
|
||||||
"coco_dir": cfg["coco_dir"], "save_dir": data_dir}
|
"coco_dir": config.coco_dir, "save_dir": data_dir}
|
||||||
preprocess_coco_dataset(param_dict)
|
preprocess_coco_dataset(param_dict)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not support dataset mode {}".format(cfg['dataset']))
|
raise ValueError("Not support dataset mode {}".format(config.dataset))
|
||||||
print("========== end preprocess dataset ==========")
|
print("========== end preprocess dataset ==========")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
|
preprocess_dataset(config.data_path)
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
|
|
||||||
help='save data directory')
|
|
||||||
args = parser.parse_args()
|
|
||||||
preprocess_dataset(cfg_unet, args.data_url)
|
|
||||||
|
|
|
@ -22,13 +22,13 @@ get_real_path() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
if [ $# != 2 ]
|
if [ $# != 3 ]
|
||||||
then
|
then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
|
echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]"
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
|
echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]"
|
||||||
echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data"
|
echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data /absolute/path/to/config"
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
@ -36,6 +36,7 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||||
export HCCL_CONNECT_TIMEOUT=600
|
export HCCL_CONNECT_TIMEOUT=600
|
||||||
export RANK_SIZE=8
|
export RANK_SIZE=8
|
||||||
DATASET=$(get_real_path $2)
|
DATASET=$(get_real_path $2)
|
||||||
|
CONFIG_PATH=$(get_real_path $3)
|
||||||
export RANK_TABLE_FILE=$(get_real_path $1)
|
export RANK_TABLE_FILE=$(get_real_path $1)
|
||||||
for((i=0;i<RANK_SIZE;i++))
|
for((i=0;i<RANK_SIZE;i++))
|
||||||
do
|
do
|
||||||
|
@ -52,7 +53,9 @@ do
|
||||||
|
|
||||||
python3 ${PROJECT_DIR}/../train.py \
|
python3 ${PROJECT_DIR}/../train.py \
|
||||||
--run_distribute=True \
|
--run_distribute=True \
|
||||||
--data_url=$DATASET > log.txt 2>&1 &
|
--data_path=$DATASET \
|
||||||
|
--config_path=$CONFIG_PATH \
|
||||||
|
--output_path './output' > log.txt 2>&1 &
|
||||||
|
|
||||||
cd ../
|
cd ../
|
||||||
done
|
done
|
||||||
|
|
|
@ -22,23 +22,24 @@ get_real_path() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
if [ $# != 2 ] && [ $# != 3 ]
|
if [ $# != 3 ] && [ $# != 4 ]
|
||||||
then
|
then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [DEVICE_ID](option, default is 0)"
|
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](option, default is 0)"
|
||||||
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ 0"
|
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/ 0"
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||||
export DEVICE_ID=0
|
export DEVICE_ID=0
|
||||||
if [ $# != 2 ]
|
if [ $# != 3 ]
|
||||||
then
|
then
|
||||||
export DEVICE_ID=$3
|
export DEVICE_ID=$4
|
||||||
fi
|
fi
|
||||||
DATASET=$(get_real_path $1)
|
DATASET=$(get_real_path $1)
|
||||||
CHECKPOINT=$(get_real_path $2)
|
CHECKPOINT=$(get_real_path $2)
|
||||||
|
CONFIG_PATH=$(get_real_path $3)
|
||||||
echo "========== start run evaluation ==========="
|
echo "========== start run evaluation ==========="
|
||||||
echo "please get log at eval.log"
|
echo "please get log at eval.log"
|
||||||
python ${PROJECT_DIR}/../eval.py --data_url=$DATASET --ckpt_path=$CHECKPOINT > eval.log 2>&1 &
|
python ${PROJECT_DIR}/../eval.py --data_path=$DATASET --checkpoint_file_path=$CHECKPOINT --config_path=$CONFIG_PATH > eval.log 2>&1 &
|
||||||
|
|
|
@ -22,23 +22,24 @@ get_real_path() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
if [ $# != 1 ] && [ $# != 2 ]
|
if [ $# != 2 ] && [ $# != 3 ]
|
||||||
then
|
then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "bash scripts/run_standalone_train.sh [DATASET] [DEVICE_ID](option, default is 0)"
|
echo "bash scripts/run_standalone_train.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](option, default is 0)"
|
||||||
echo "for example: bash run_standalone_train.sh /path/to/data/ 0"
|
echo "for example: bash run_standalone_train.sh /path/to/data/ /path/to/config/ 0"
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||||
export DEVICE_ID=0
|
export DEVICE_ID=0
|
||||||
if [ $# != 1 ]
|
if [ $# != 2 ]
|
||||||
then
|
then
|
||||||
export DEVICE_ID=$2
|
export DEVICE_ID=$3
|
||||||
fi
|
fi
|
||||||
|
|
||||||
DATASET=$(get_real_path $1)
|
DATASET=$(get_real_path $1)
|
||||||
|
CONFIG_PATH=$(get_real_path $2)
|
||||||
echo "========== start run training ==========="
|
echo "========== start run training ==========="
|
||||||
echo "please get log at train.log"
|
echo "please get log at train.log"
|
||||||
python ${PROJECT_DIR}/../train.py --data_url=$DATASET > train.log 2>&1 &
|
python ${PROJECT_DIR}/../train.py --data_path=$DATASET --config_path=$CONFIG_PATH --output './output'> train.log 2>&1 &
|
||||||
|
|
|
@ -1,178 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
cfg_unet_medical = {
|
|
||||||
'model': 'unet_medical',
|
|
||||||
'crop': [388 / 572, 388 / 572],
|
|
||||||
'img_size': [572, 572],
|
|
||||||
'lr': 0.0001,
|
|
||||||
'epochs': 400,
|
|
||||||
'repeat': 400,
|
|
||||||
'distribute_epochs': 1600,
|
|
||||||
'batchsize': 16,
|
|
||||||
'cross_valid_ind': 1,
|
|
||||||
'num_classes': 2,
|
|
||||||
'num_channels': 1,
|
|
||||||
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
'weight_decay': 0.0005,
|
|
||||||
'loss_scale': 1024.0,
|
|
||||||
'FixedLossScaleManager': 1024.0,
|
|
||||||
|
|
||||||
'resume': False,
|
|
||||||
'resume_ckpt': './',
|
|
||||||
'transfer_training': False,
|
|
||||||
'filter_weight': ['outc.weight', 'outc.bias'],
|
|
||||||
'eval_activate': 'Softmax',
|
|
||||||
'eval_resize': False
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg_unet_nested = {
|
|
||||||
'model': 'unet_nested',
|
|
||||||
'crop': None,
|
|
||||||
'img_size': [576, 576],
|
|
||||||
'lr': 0.0001,
|
|
||||||
'epochs': 400,
|
|
||||||
'repeat': 400,
|
|
||||||
'distribute_epochs': 1600,
|
|
||||||
'batchsize': 16,
|
|
||||||
'cross_valid_ind': 1,
|
|
||||||
'num_classes': 2,
|
|
||||||
'num_channels': 1,
|
|
||||||
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
'weight_decay': 0.0005,
|
|
||||||
'loss_scale': 1024.0,
|
|
||||||
'FixedLossScaleManager': 1024.0,
|
|
||||||
'use_bn': True,
|
|
||||||
'use_ds': True,
|
|
||||||
'use_deconv': True,
|
|
||||||
|
|
||||||
'resume': False,
|
|
||||||
'resume_ckpt': './',
|
|
||||||
'transfer_training': False,
|
|
||||||
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
|
|
||||||
'eval_activate': 'Softmax',
|
|
||||||
'eval_resize': False
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg_unet_nested_cell = {
|
|
||||||
'model': 'unet_nested',
|
|
||||||
'dataset': 'Cell_nuclei',
|
|
||||||
'crop': None,
|
|
||||||
'img_size': [96, 96],
|
|
||||||
'lr': 3e-4,
|
|
||||||
'epochs': 200,
|
|
||||||
'repeat': 10,
|
|
||||||
'distribute_epochs': 1600,
|
|
||||||
'batchsize': 16,
|
|
||||||
'cross_valid_ind': 1,
|
|
||||||
'num_classes': 2,
|
|
||||||
'num_channels': 3,
|
|
||||||
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
'weight_decay': 0.0005,
|
|
||||||
'loss_scale': 1024.0,
|
|
||||||
'FixedLossScaleManager': 1024.0,
|
|
||||||
'use_bn': True,
|
|
||||||
'use_ds': True,
|
|
||||||
'use_deconv': True,
|
|
||||||
|
|
||||||
'resume': False,
|
|
||||||
'resume_ckpt': './',
|
|
||||||
'transfer_training': False,
|
|
||||||
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
|
|
||||||
'eval_activate': 'Softmax',
|
|
||||||
'eval_resize': False
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg_unet_simple = {
|
|
||||||
'model': 'unet_simple',
|
|
||||||
'crop': None,
|
|
||||||
'img_size': [576, 576],
|
|
||||||
'lr': 0.0001,
|
|
||||||
'epochs': 400,
|
|
||||||
'repeat': 400,
|
|
||||||
'distribute_epochs': 2400,
|
|
||||||
'batchsize': 16,
|
|
||||||
'cross_valid_ind': 1,
|
|
||||||
'num_classes': 2,
|
|
||||||
'num_channels': 1,
|
|
||||||
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
'weight_decay': 0.0005,
|
|
||||||
'loss_scale': 1024.0,
|
|
||||||
'FixedLossScaleManager': 1024.0,
|
|
||||||
|
|
||||||
'resume': False,
|
|
||||||
'resume_ckpt': './',
|
|
||||||
'transfer_training': False,
|
|
||||||
'filter_weight': ["final.weight"],
|
|
||||||
'eval_activate': 'Softmax',
|
|
||||||
'eval_resize': False
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg_unet_simple_coco = {
|
|
||||||
'model': 'unet_simple',
|
|
||||||
'dataset': 'COCO',
|
|
||||||
'split': 1.0,
|
|
||||||
'img_size': [512, 512],
|
|
||||||
'lr': 3e-4,
|
|
||||||
'epochs': 80,
|
|
||||||
'repeat': 1,
|
|
||||||
'distribute_epochs': 120,
|
|
||||||
'cross_valid_ind': 1,
|
|
||||||
'batchsize': 16,
|
|
||||||
'num_channels': 3,
|
|
||||||
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
'weight_decay': 0.0005,
|
|
||||||
'loss_scale': 1024.0,
|
|
||||||
'FixedLossScaleManager': 1024.0,
|
|
||||||
|
|
||||||
'resume': False,
|
|
||||||
'resume_ckpt': './',
|
|
||||||
'transfer_training': False,
|
|
||||||
'filter_weight': ["final.weight"],
|
|
||||||
'eval_activate': 'Softmax',
|
|
||||||
'eval_resize': False,
|
|
||||||
|
|
||||||
'num_classes': 81,
|
|
||||||
'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
||||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
|
||||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
|
||||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
|
||||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
|
||||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
|
||||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
|
||||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
|
||||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
|
||||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
|
||||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
|
||||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
|
||||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
|
||||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
|
||||||
'teddy bear', 'hair drier', 'toothbrush'),
|
|
||||||
# change the following settings to real path
|
|
||||||
'anno_json': '/data/coco2017/annotations/instances_train2017.json',
|
|
||||||
'val_anno_json': '/data/coco2017/annotations/instances_val2017.json',
|
|
||||||
'coco_dir': '/data/coco2017/train2017',
|
|
||||||
'val_coco_dir': '/data/coco2017/val2017'
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg_unet = cfg_unet_simple
|
|
||||||
if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']:
|
|
||||||
print("ISBI dataset not support resize to original image size when in evaluation.")
|
|
||||||
cfg_unet['eval_resize'] = False
|
|
|
@ -122,8 +122,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
|
||||||
ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
|
ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
|
||||||
ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
|
ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
|
||||||
|
|
||||||
if do_crop:
|
if do_crop != "None":
|
||||||
resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))]
|
resize_size = [int(img_size[x] * do_crop[x] / 572) for x in range(len(img_size))]
|
||||||
else:
|
else:
|
||||||
resize_size = img_size
|
resize_size = img_size
|
||||||
c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
|
c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
|
||||||
|
@ -146,7 +146,7 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
|
||||||
train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
|
train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
|
||||||
train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
|
train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
|
||||||
|
|
||||||
if do_crop:
|
if do_crop != "None":
|
||||||
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
|
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
|
||||||
post_process = data_post_process
|
post_process = data_post_process
|
||||||
train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
|
train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
|
||||||
|
@ -157,7 +157,7 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
|
||||||
valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
|
valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
|
||||||
valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
|
valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
|
||||||
valid_ds = valid_ds.project(columns=["image", "mask"])
|
valid_ds = valid_ds.project(columns=["image", "mask"])
|
||||||
if do_crop:
|
if do_crop != "None":
|
||||||
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
|
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
|
||||||
post_process = data_post_process
|
post_process = data_post_process
|
||||||
valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
|
valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Parse arguments"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ast
|
||||||
|
import argparse
|
||||||
|
from pprint import pprint, pformat
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
_config_path = "./unet_simple_config.yaml"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""
|
||||||
|
Configuration namespace. Convert dictionary to members.
|
||||||
|
"""
|
||||||
|
def __init__(self, cfg_dict):
|
||||||
|
for k, v in cfg_dict.items():
|
||||||
|
if isinstance(v, (list, tuple)):
|
||||||
|
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||||
|
else:
|
||||||
|
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return pformat(self.__dict__)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="unet_simple_config.yaml"):
|
||||||
|
"""
|
||||||
|
Parse command line arguments to the configuration according to the default yaml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: Parent parser.
|
||||||
|
cfg: Base configuration.
|
||||||
|
helper: Helper description.
|
||||||
|
cfg_path: Path to the default yaml config.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||||
|
parents=[parser])
|
||||||
|
helper = {} if helper is None else helper
|
||||||
|
choices = {} if choices is None else choices
|
||||||
|
for item in cfg:
|
||||||
|
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||||
|
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||||
|
choice = choices[item] if item in choices else None
|
||||||
|
if isinstance(cfg[item], bool):
|
||||||
|
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
else:
|
||||||
|
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml(yaml_path):
|
||||||
|
"""
|
||||||
|
Parse the yaml config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
yaml_path: Path to the yaml config.
|
||||||
|
"""
|
||||||
|
with open(yaml_path, 'r') as fin:
|
||||||
|
try:
|
||||||
|
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||||
|
cfgs = [x for x in cfgs]
|
||||||
|
if len(cfgs) == 1:
|
||||||
|
cfg_helper = {}
|
||||||
|
cfg = cfgs[0]
|
||||||
|
elif len(cfgs) == 2:
|
||||||
|
cfg, cfg_helper = cfgs
|
||||||
|
else:
|
||||||
|
raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
|
||||||
|
print(cfg_helper)
|
||||||
|
except:
|
||||||
|
raise ValueError("Failed to parse yaml")
|
||||||
|
return cfg, cfg_helper
|
||||||
|
|
||||||
|
|
||||||
|
def merge(args, cfg):
|
||||||
|
"""
|
||||||
|
Merge the base config from yaml file and command line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Command line arguments.
|
||||||
|
cfg: Base configuration.
|
||||||
|
"""
|
||||||
|
args_var = vars(args)
|
||||||
|
for item in args_var:
|
||||||
|
cfg[item] = args_var[item]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""
|
||||||
|
Get Config according to the yaml file and cli arguments.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../unet_simple_config.yaml"),
|
||||||
|
help="Config file path")
|
||||||
|
path_args, _ = parser.parse_known_args()
|
||||||
|
default, helper = parse_yaml(path_args.config_path)
|
||||||
|
pprint(default)
|
||||||
|
args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
|
||||||
|
final_config = merge(args, default)
|
||||||
|
return Config(final_config)
|
||||||
|
|
||||||
|
config = get_config()
|
|
@ -15,12 +15,12 @@
|
||||||
|
|
||||||
"""Device adapter for ModelArts"""
|
"""Device adapter for ModelArts"""
|
||||||
|
|
||||||
from utils.config import config
|
from src.model_utils.config import config
|
||||||
|
|
||||||
if config.enable_modelarts:
|
if config.enable_modelarts:
|
||||||
from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
else:
|
else:
|
||||||
from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
|
@ -18,7 +18,7 @@
|
||||||
import os
|
import os
|
||||||
import functools
|
import functools
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from utils.config import config
|
from src.model_utils.config import config
|
||||||
|
|
||||||
_global_sync_count = 0
|
_global_sync_count = 0
|
||||||
|
|
|
@ -21,6 +21,7 @@ from mindspore import nn
|
||||||
from mindspore.ops import operations as ops
|
from mindspore.ops import operations as ops
|
||||||
from mindspore.train.callback import Callback
|
from mindspore.train.callback import Callback
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
from src.model_utils.config import config
|
||||||
|
|
||||||
class UnetEval(nn.Cell):
|
class UnetEval(nn.Cell):
|
||||||
"""
|
"""
|
||||||
|
@ -63,10 +64,9 @@ def apply_eval(eval_param_dict):
|
||||||
|
|
||||||
class dice_coeff(nn.Metric):
|
class dice_coeff(nn.Metric):
|
||||||
"""Unet Metric, return dice coefficient and IOU."""
|
"""Unet Metric, return dice coefficient and IOU."""
|
||||||
def __init__(self, cfg_unet, print_res=True):
|
def __init__(self, print_res=True):
|
||||||
super(dice_coeff, self).__init__()
|
super(dice_coeff, self).__init__()
|
||||||
self.clear()
|
self.clear()
|
||||||
self.cfg_unet = cfg_unet
|
|
||||||
self.print_res = print_res
|
self.print_res = print_res
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
|
@ -84,20 +84,20 @@ class dice_coeff(nn.Metric):
|
||||||
if b != 1:
|
if b != 1:
|
||||||
raise ValueError('Batch size should be 1 when in evaluation.')
|
raise ValueError('Batch size should be 1 when in evaluation.')
|
||||||
y = y.reshape((h, w, c))
|
y = y.reshape((h, w, c))
|
||||||
if self.cfg_unet["eval_activate"].lower() == "softmax":
|
if config.eval_activate.lower() == "softmax":
|
||||||
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
|
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
|
||||||
if self.cfg_unet["eval_resize"]:
|
if config.eval_resize:
|
||||||
y_pred = []
|
y_pred = []
|
||||||
for i in range(self.cfg_unet["num_classes"]):
|
for i in range(config.num_classes):
|
||||||
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
|
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
|
||||||
y_pred = np.stack(y_pred, axis=-1)
|
y_pred = np.stack(y_pred, axis=-1)
|
||||||
else:
|
else:
|
||||||
y_pred = y_softmax
|
y_pred = y_softmax
|
||||||
elif self.cfg_unet["eval_activate"].lower() == "argmax":
|
elif config.eval_activate.lower() == "argmax":
|
||||||
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
|
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
|
||||||
y_pred = []
|
y_pred = []
|
||||||
for i in range(self.cfg_unet["num_classes"]):
|
for i in range(config.num_classes):
|
||||||
if self.cfg_unet["eval_resize"]:
|
if config.eval_resize:
|
||||||
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
|
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
|
||||||
else:
|
else:
|
||||||
y_pred.append(np.float32(y_argmax == i))
|
y_pred.append(np.float32(y_argmax == i))
|
||||||
|
|
|
@ -14,14 +14,12 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import logging
|
import logging
|
||||||
import ast
|
|
||||||
|
|
||||||
import mindspore
|
import mindspore
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Model, context
|
from mindspore import Model, context
|
||||||
from mindspore.communication.management import init, get_group_size, get_rank
|
from mindspore.communication.management import init
|
||||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
@ -31,133 +29,110 @@ from src.unet_nested import NestedUNet, UNet
|
||||||
from src.data_loader import create_dataset, create_multi_class_dataset
|
from src.data_loader import create_dataset, create_multi_class_dataset
|
||||||
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
|
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
|
||||||
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
|
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
|
||||||
from src.config import cfg_unet
|
|
||||||
from src.eval_callback import EvalCallBack
|
from src.eval_callback import EvalCallBack
|
||||||
|
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
from src.model_utils.config import config
|
||||||
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
|
from src.model_utils.device_adapter import get_rank_id, get_device_num
|
||||||
|
|
||||||
|
device_id = int(os.getenv("DEVICE_ID"))
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||||
|
|
||||||
mindspore.set_seed(1)
|
mindspore.set_seed(1)
|
||||||
|
|
||||||
def train_net(args_opt,
|
@moxing_wrapper()
|
||||||
cross_valid_ind=1,
|
def train_net(cross_valid_ind=1,
|
||||||
epochs=400,
|
epochs=400,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
lr=0.0001,
|
lr=0.0001):
|
||||||
cfg=None):
|
|
||||||
rank = 0
|
rank = 0
|
||||||
group_size = 1
|
group_size = 1
|
||||||
data_dir = args_opt.data_url
|
data_dir = config.data_path
|
||||||
run_distribute = args_opt.run_distribute
|
run_distribute = config.run_distribute
|
||||||
if run_distribute:
|
if run_distribute:
|
||||||
init()
|
init()
|
||||||
group_size = get_group_size()
|
group_size = get_device_num()
|
||||||
rank = get_rank()
|
rank = get_rank_id()
|
||||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||||
context.set_auto_parallel_context(parallel_mode=parallel_mode,
|
context.set_auto_parallel_context(parallel_mode=parallel_mode,
|
||||||
device_num=group_size,
|
device_num=group_size,
|
||||||
gradients_mean=False)
|
gradients_mean=False)
|
||||||
need_slice = False
|
need_slice = False
|
||||||
if cfg['model'] == 'unet_medical':
|
if config.model_name == 'unet_medical':
|
||||||
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
net = UNetMedical(n_channels=config.num_channels, n_classes=config.num_classes)
|
||||||
elif cfg['model'] == 'unet_nested':
|
elif config.model_name == 'unet_nested':
|
||||||
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
|
net = NestedUNet(in_channel=config.num_channels, n_class=config.num_classes, use_deconv=config.use_deconv,
|
||||||
use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
|
use_bn=config.use_bn, use_ds=config.use_ds)
|
||||||
need_slice = cfg['use_ds']
|
need_slice = config.use_ds
|
||||||
elif cfg['model'] == 'unet_simple':
|
elif config.model_name == 'unet_simple':
|
||||||
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
|
net = UNet(in_channel=config.num_channels, n_class=config.num_classes)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported model: {}".format(cfg['model']))
|
raise ValueError("Unsupported model: {}".format(config.model_name))
|
||||||
|
|
||||||
if cfg['resume']:
|
if config.resume:
|
||||||
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
param_dict = load_checkpoint(config.resume_ckpt)
|
||||||
if cfg['transfer_training']:
|
if config.transfer_training:
|
||||||
filter_checkpoint_parameter_by_list(param_dict, cfg['filter_weight'])
|
filter_checkpoint_parameter_by_list(param_dict, config.filter_weight)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
if 'use_ds' in cfg and cfg['use_ds']:
|
if hasattr(config, "use_ds") and config.use_ds:
|
||||||
criterion = MultiCrossEntropyWithLogits()
|
criterion = MultiCrossEntropyWithLogits()
|
||||||
else:
|
else:
|
||||||
criterion = CrossEntropyWithLogits()
|
criterion = CrossEntropyWithLogits()
|
||||||
if 'dataset' in cfg and cfg['dataset'] != "ISBI":
|
if hasattr(config, "dataset") and config.dataset != "ISBI":
|
||||||
repeat = cfg['repeat'] if 'repeat' in cfg else 1
|
|
||||||
split = cfg['split'] if 'split' in cfg else 0.8
|
|
||||||
dataset_sink_mode = True
|
dataset_sink_mode = True
|
||||||
per_print_times = 0
|
per_print_times = 0
|
||||||
train_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], repeat, batch_size,
|
repeat = config.repeat if hasattr(config, "repeat") else 1
|
||||||
num_classes=cfg['num_classes'], is_train=True, augment=True,
|
split = config.split if hasattr(config, "split") else 0.8
|
||||||
|
train_dataset = create_multi_class_dataset(data_dir, config.image_size, repeat, batch_size,
|
||||||
|
num_classes=config.num_classes, is_train=True, augment=True,
|
||||||
split=split, rank=rank, group_size=group_size, shuffle=True)
|
split=split, rank=rank, group_size=group_size, shuffle=True)
|
||||||
valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
|
valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
|
||||||
num_classes=cfg['num_classes'], is_train=False,
|
num_classes=config.num_classes, is_train=False,
|
||||||
eval_resize=cfg["eval_resize"], split=split,
|
eval_resize=config.eval_resize, split=split,
|
||||||
python_multiprocessing=False, shuffle=False)
|
python_multiprocessing=False, shuffle=False)
|
||||||
else:
|
else:
|
||||||
repeat = cfg['repeat']
|
repeat = config.repeat
|
||||||
dataset_sink_mode = False
|
dataset_sink_mode = False
|
||||||
per_print_times = 1
|
per_print_times = 1
|
||||||
train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
|
train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
|
||||||
run_distribute, cfg["crop"], cfg['img_size'])
|
run_distribute, config.crop, config.image_size)
|
||||||
train_data_size = train_dataset.get_dataset_size()
|
train_data_size = train_dataset.get_dataset_size()
|
||||||
print("dataset length is:", train_data_size)
|
print("dataset length is:", train_data_size)
|
||||||
|
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
|
||||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||||
keep_checkpoint_max=cfg['keep_checkpoint_max'])
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
|
ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name),
|
||||||
directory='./ckpt_{}/'.format(device_id),
|
directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id),
|
||||||
config=ckpt_config)
|
config=ckpt_config)
|
||||||
|
|
||||||
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
|
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay,
|
||||||
loss_scale=cfg['loss_scale'])
|
loss_scale=config.loss_scale)
|
||||||
|
|
||||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
|
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.FixedLossScaleManager, False)
|
||||||
|
|
||||||
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
|
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
|
||||||
|
|
||||||
print("============== Starting Training ==============")
|
print("============== Starting Training ==============")
|
||||||
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
|
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
|
||||||
if args_opt.run_eval:
|
if config.run_eval:
|
||||||
eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
|
eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
|
||||||
metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
|
metrics={"dice_coeff": dice_coeff(False)})
|
||||||
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics}
|
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
|
||||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
|
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
|
||||||
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
|
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
|
||||||
ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
|
ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
|
||||||
metrics_name=args_opt.eval_metrics)
|
metrics_name=config.eval_metrics)
|
||||||
callbacks.append(eval_cb)
|
callbacks.append(eval_cb)
|
||||||
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
|
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
|
||||||
print("============== End Training ==============")
|
print("============== End Training ==============")
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
|
|
||||||
help='data directory')
|
|
||||||
parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
|
|
||||||
default=False, help='Run distribute, default: false.')
|
|
||||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
|
||||||
help="Run evaluation when training, default is False.")
|
|
||||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
|
||||||
help="Save best checkpoint when run_eval is True, default is True.")
|
|
||||||
parser.add_argument("--eval_start_epoch", type=int, default=0,
|
|
||||||
help="Evaluation start epoch when run_eval is True, default is 0.")
|
|
||||||
parser.add_argument("--eval_interval", type=int, default=1,
|
|
||||||
help="Evaluation interval when run_eval is True, default is 1.")
|
|
||||||
parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
|
|
||||||
help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], "
|
|
||||||
"default is dice_coeff.")
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||||
args = get_args()
|
|
||||||
print("Training setting:", args)
|
|
||||||
|
|
||||||
epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs']
|
epoch_size = config.epochs if not config.run_distribute else config.distribute_epochs
|
||||||
train_net(args_opt=args,
|
train_net(cross_valid_ind=config.cross_valid_ind,
|
||||||
cross_valid_ind=cfg_unet['cross_valid_ind'],
|
|
||||||
epochs=epoch_size,
|
epochs=epoch_size,
|
||||||
batch_size=cfg_unet['batchsize'],
|
batch_size=config.batch_size,
|
||||||
lr=cfg_unet['lr'],
|
lr=config.lr)
|
||||||
cfg=cfg_unet)
|
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path/"
|
||||||
|
device_target: 'Ascend'
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
model_name: 'unet_medical'
|
||||||
|
run_eval: False
|
||||||
|
run_distribute: False
|
||||||
|
crop: [388, 388]
|
||||||
|
image_size : [572, 572]
|
||||||
|
lr: 0.0001
|
||||||
|
epochs: 400
|
||||||
|
repeat: 400
|
||||||
|
distribute_epochs: 1600
|
||||||
|
batch_size: 16
|
||||||
|
cross_valid_ind: 1
|
||||||
|
num_classes: 2
|
||||||
|
num_channels: 1
|
||||||
|
weight_decay: 0.0005
|
||||||
|
loss_scale: 1024.0
|
||||||
|
FixedLossScaleManager: 1024.0
|
||||||
|
resume: False
|
||||||
|
resume_ckpt: './'
|
||||||
|
transfer_training: False
|
||||||
|
filter_weight: ['outc.weight', 'outc.bias']
|
||||||
|
|
||||||
|
#Eval options
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
eval_activate: 'Softmax'
|
||||||
|
eval_resize: False
|
||||||
|
checkpoint_path: './checkpoint/'
|
||||||
|
checkpoint_file_path: 'ckpt_unet_medical_adam-4-75.ckpt'
|
||||||
|
rst_path: './result_Files/'
|
||||||
|
|
||||||
|
# Export options
|
||||||
|
width: 572
|
||||||
|
height: 572
|
||||||
|
file_name: unet
|
||||||
|
file_format: AIR
|
||||||
|
|
||||||
|
---
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
checkpoint_url: 'The location of checkpoint for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
load_path: 'The location of checkpoint for obs'
|
||||||
|
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
num_classes: 'Class for dataset'
|
||||||
|
batch_size: "Batch size for training and evaluation"
|
||||||
|
weight_decay: "Weight decay."
|
||||||
|
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||||
|
checkpoint_path: "The location of the checkpoint file."
|
||||||
|
checkpoint_file_path: "The location of the checkpoint file."
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path/"
|
||||||
|
device_target: 'Ascend'
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
model_name: 'unet_nested'
|
||||||
|
run_eval: False
|
||||||
|
run_distribute: False
|
||||||
|
dataset: 'Cell_nuclei'
|
||||||
|
crop: None
|
||||||
|
image_size : [96, 96]
|
||||||
|
lr: 0.0003
|
||||||
|
epochs: 200
|
||||||
|
repeat: 10
|
||||||
|
distribute_epochs: 1600
|
||||||
|
batch_size: 16
|
||||||
|
cross_valid_ind: 1
|
||||||
|
num_classes: 2
|
||||||
|
num_channels: 3
|
||||||
|
weight_decay: 0.0005
|
||||||
|
loss_scale: 1024.0
|
||||||
|
FixedLossScaleManager: 1024.0
|
||||||
|
use_ds: False
|
||||||
|
use_bn: False
|
||||||
|
use_deconv: True
|
||||||
|
resume: False
|
||||||
|
resume_ckpt: './'
|
||||||
|
transfer_training: False
|
||||||
|
filter_weight: ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
|
||||||
|
|
||||||
|
#Eval options
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
eval_activate: 'Softmax'
|
||||||
|
eval_resize: False
|
||||||
|
checkpoint_path: './checkpoint/'
|
||||||
|
checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt'
|
||||||
|
rst_path: './result_Files/'
|
||||||
|
|
||||||
|
# Export options
|
||||||
|
width: 572
|
||||||
|
height: 572
|
||||||
|
file_name: unet
|
||||||
|
file_format: AIR
|
||||||
|
|
||||||
|
---
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
checkpoint_url: 'The location of checkpoint for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
load_path: 'The location of checkpoint for obs'
|
||||||
|
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
num_classes: 'Class for dataset'
|
||||||
|
batch_size: "Batch size for training and evaluation"
|
||||||
|
weight_decay: "Weight decay."
|
||||||
|
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||||
|
checkpoint_path: "The location of the checkpoint file."
|
||||||
|
checkpoint_file_path: "The location of the checkpoint file."
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path/"
|
||||||
|
device_target: 'Ascend'
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
model_name: 'unet_nested'
|
||||||
|
run_eval: False
|
||||||
|
run_distribute: False
|
||||||
|
crop: None
|
||||||
|
image_size : [576, 576]
|
||||||
|
lr: 0.0001
|
||||||
|
epochs: 400
|
||||||
|
repeat: 400
|
||||||
|
distribute_epochs: 1600
|
||||||
|
batch_size: 16
|
||||||
|
cross_valid_ind: 1
|
||||||
|
num_classes: 2
|
||||||
|
num_channels: 1
|
||||||
|
weight_decay: 0.0005
|
||||||
|
loss_scale: 1024.0
|
||||||
|
FixedLossScaleManager: 1024.0
|
||||||
|
use_ds: True
|
||||||
|
use_bn: True
|
||||||
|
use_deconv: True
|
||||||
|
resume: False
|
||||||
|
resume_ckpt: './'
|
||||||
|
transfer_training: False
|
||||||
|
filter_weight: ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
|
||||||
|
|
||||||
|
#Eval options
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
eval_activate: 'Softmax'
|
||||||
|
eval_resize: False
|
||||||
|
checkpoint_path: './checkpoint/'
|
||||||
|
checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt'
|
||||||
|
rst_path: './result_Files/'
|
||||||
|
|
||||||
|
# Export options
|
||||||
|
width: 572
|
||||||
|
height: 572
|
||||||
|
file_name: unet
|
||||||
|
file_format: AIR
|
||||||
|
|
||||||
|
---
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
checkpoint_url: 'The location of checkpoint for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
load_path: 'The location of checkpoint for obs'
|
||||||
|
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
num_classes: 'Class for dataset'
|
||||||
|
batch_size: "Batch size for training and evaluation"
|
||||||
|
weight_decay: "Weight decay."
|
||||||
|
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||||
|
checkpoint_path: "The location of the checkpoint file."
|
||||||
|
checkpoint_file_path: "The location of the checkpoint file."
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path/"
|
||||||
|
device_target: 'Ascend'
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
model_name: 'unet_simple'
|
||||||
|
run_eval: False
|
||||||
|
run_distribute: False
|
||||||
|
dataset: 'COCO'
|
||||||
|
split : 1.0
|
||||||
|
image_size : [512, 512]
|
||||||
|
lr: 0.0001
|
||||||
|
epochs: 80
|
||||||
|
repeat: 1
|
||||||
|
distribute_epochs: 120
|
||||||
|
batch_size: 16
|
||||||
|
cross_valid_ind: 1
|
||||||
|
num_classes: 81
|
||||||
|
num_channels: 3
|
||||||
|
weight_decay: 0.0005
|
||||||
|
loss_scale: 1024.0
|
||||||
|
FixedLossScaleManager: 1024.0
|
||||||
|
resume: False
|
||||||
|
resume_ckpt: './'
|
||||||
|
transfer_training: False
|
||||||
|
filter_weight: ['final1.weight']
|
||||||
|
coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||||
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||||
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||||
|
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||||
|
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||||
|
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||||
|
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||||
|
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||||
|
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||||
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||||
|
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||||
|
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||||
|
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||||
|
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||||
|
'teddy bear', 'hair drier', 'toothbrush']
|
||||||
|
|
||||||
|
anno_json: '/data/coco2017/annotations/instances_train2017.json'
|
||||||
|
val_anno_json: '/data/coco2017/annotations/instances_val2017.json'
|
||||||
|
coco_dir: '/data/coco2017/train2017'
|
||||||
|
val_coco_dir: '/data/coco2017/val2017'
|
||||||
|
|
||||||
|
#Eval options
|
||||||
|
eval_metrics: "dice_coeff"
|
||||||
|
eval_start_epoch: 0
|
||||||
|
eval_interval: 1
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
eval_activate: 'Softmax'
|
||||||
|
eval_resize: False
|
||||||
|
checkpoint_path: './checkpoint/'
|
||||||
|
checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
|
||||||
|
rst_path: './result_Files/'
|
||||||
|
|
||||||
|
# Export options
|
||||||
|
width: 572
|
||||||
|
height: 572
|
||||||
|
file_name: unet
|
||||||
|
file_format: AIR
|
||||||
|
|
||||||
|
---
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
checkpoint_url: 'The location of checkpoint for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
load_path: 'The location of checkpoint for obs'
|
||||||
|
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
num_classes: 'Class for dataset'
|
||||||
|
batch_size: "Batch size for training and evaluation"
|
||||||
|
weight_decay: "Weight decay."
|
||||||
|
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||||
|
checkpoint_path: "The location of the checkpoint file."
|
||||||
|
checkpoint_file_path: "The location of the checkpoint file."
|
|
@ -0,0 +1,68 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path/"
|
||||||
|
device_target: 'Ascend'
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
model_name: 'unet_simple'
|
||||||
|
run_eval: False
|
||||||
|
run_distribute: False
|
||||||
|
crop: None
|
||||||
|
image_size : [576, 576]
|
||||||
|
lr: 0.0001
|
||||||
|
epochs: 400
|
||||||
|
repeat: 400
|
||||||
|
distribute_epochs: 2400
|
||||||
|
batch_size: 16
|
||||||
|
cross_valid_ind: 1
|
||||||
|
num_classes: 2
|
||||||
|
num_channels: 1
|
||||||
|
weight_decay: 0.0005
|
||||||
|
loss_scale: 1024.0
|
||||||
|
FixedLossScaleManager: 1024.0
|
||||||
|
resume: False
|
||||||
|
resume_ckpt: './'
|
||||||
|
transfer_training: False
|
||||||
|
filter_weight: ['final1.weight']
|
||||||
|
|
||||||
|
#Eval options
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
eval_activate: 'Softmax'
|
||||||
|
eval_resize: False
|
||||||
|
checkpoint_path: './checkpoint/'
|
||||||
|
checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
|
||||||
|
rst_path: './result_Files/'
|
||||||
|
|
||||||
|
# Export options
|
||||||
|
width: 572
|
||||||
|
height: 572
|
||||||
|
file_name: unet
|
||||||
|
file_format: AIR
|
||||||
|
|
||||||
|
---
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
checkpoint_url: 'The location of checkpoint for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
load_path: 'The location of checkpoint for obs'
|
||||||
|
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
num_classes: 'Class for dataset'
|
||||||
|
batch_size: "Batch size for training and evaluation"
|
||||||
|
weight_decay: "Weight decay."
|
||||||
|
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||||
|
checkpoint_path: "The location of the checkpoint file."
|
||||||
|
checkpoint_file_path: "The location of the checkpoint file."
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
if [ $# -ne 2 ]
|
if [ $# -ne 2 ]
|
||||||
then
|
then
|
||||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
|
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [DATA_PATH]"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,14 @@ if [ $# != 2 ]
|
||||||
then
|
then
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
|
echo "bash scripts/run_standalone_eval.sh [DATA_PATH] [CHECKPOINT]"
|
||||||
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
|
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# != 2 ]
|
if [ $# != 2 ]
|
||||||
then
|
then
|
||||||
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
|
echo "Usage: sh run_eval_ascend.sh [DATA_PATH] [CHECKPOINT_PATH]"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
if [ $# -ne 1 ]
|
if [ $# -ne 1 ]
|
||||||
then
|
then
|
||||||
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
|
echo "Usage: sh run_distribute_train_ascend.sh [DATA_PATH]"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -100,26 +100,26 @@ If you want to run in modelarts, please check the official documentation of [mod
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
├── model_zoo
|
├── model_zoo
|
||||||
├── README.md // descriptions about all the models
|
├── README.md // descriptions about all the models
|
||||||
├── textcnn
|
├── textcnn
|
||||||
├── README.md // descriptions about textcnn
|
├── README.md // descriptions about textcnn
|
||||||
├──scripts
|
├──scripts
|
||||||
│ ├── run_train.sh // shell script for distributed on Ascend
|
│ ├── run_train.sh // shell script for distributed on Ascend
|
||||||
│ ├── run_eval.sh // shell script for evaluation on Ascend
|
│ ├── run_eval.sh // shell script for evaluation on Ascend
|
||||||
├── src
|
├── src
|
||||||
│ ├── dataset.py // Processing dataset
|
│ ├── dataset.py // Processing dataset
|
||||||
│ ├── textcnn.py // textcnn architecture
|
│ ├── textcnn.py // textcnn architecture
|
||||||
├── utils
|
├── utils
|
||||||
│ ├──device_adapter.py // device adapter
|
│ ├──device_adapter.py // device adapter
|
||||||
│ ├──local_adapter.py // local adapter
|
│ ├──local_adapter.py // local adapter
|
||||||
│ ├──moxing_adapter.py // moxing adapter
|
│ ├──moxing_adapter.py // moxing adapter
|
||||||
│ ├── config.py // parameter analysis
|
│ ├──config.py // parameter analysis
|
||||||
├── mr_config.yaml // parameter configuration
|
├── mr_config.yaml // parameter configuration
|
||||||
├── sst2_config.yaml // parameter configuration
|
├── sst2_config.yaml // parameter configuration
|
||||||
├── subj_config.yaml // parameter configuration
|
├── subj_config.yaml // parameter configuration
|
||||||
├── train.py // training script
|
├── train.py // training script
|
||||||
├── eval.py // evaluation script
|
├── eval.py // evaluation script
|
||||||
├── export.py // export checkpoint to other format file
|
├── export.py // export checkpoint to other format file
|
||||||
```
|
```
|
||||||
|
|
||||||
## [Script Parameters](#contents)
|
## [Script Parameters](#contents)
|
||||||
|
|
|
@ -22,9 +22,9 @@ from mindspore import context
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
from utils.moxing_adapter import moxing_wrapper
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from utils.device_adapter import get_device_id
|
from src.model_utils.device_adapter import get_device_id
|
||||||
from utils.config import config
|
from src.model_utils.config import config
|
||||||
from src.textcnn import TextCNN
|
from src.textcnn import TextCNN
|
||||||
from src.dataset import MovieReview, SST2, Subjectivity
|
from src.dataset import MovieReview, SST2, Subjectivity
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||||
|
|
||||||
from utils.config import config
|
from src.model_utils.config import config
|
||||||
from src.textcnn import TextCNN
|
from src.textcnn import TextCNN
|
||||||
from src.dataset import MovieReview, SST2, Subjectivity
|
from src.dataset import MovieReview, SST2, Subjectivity
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
|
||||||
word_len: 51
|
word_len: 51
|
||||||
vec_length: 40
|
vec_length: 40
|
||||||
base_lr: 1e-3
|
base_lr: 1e-3
|
||||||
|
label_dir: ""
|
||||||
|
result_dir: ""
|
||||||
|
result_path: './preprocess_Result/'
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -16,35 +16,21 @@
|
||||||
##############postprocess#################
|
##############postprocess#################
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from src.config import cfg_mr, cfg_subj, cfg_sst2
|
from src.model_utils.config import config
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='postprocess')
|
|
||||||
parser.add_argument('--label_dir', type=str, default="", help='label data dir')
|
|
||||||
parser.add_argument('--result_dir', type=str, default="", help="infer result dir")
|
|
||||||
parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args.dataset == 'MR':
|
|
||||||
cfg = cfg_mr
|
|
||||||
elif args.dataset == 'SUBJ':
|
|
||||||
cfg = cfg_subj
|
|
||||||
elif args.dataset == 'SST2':
|
|
||||||
cfg = cfg_sst2
|
|
||||||
|
|
||||||
file_prefix = 'textcnn_bs' + str(cfg.batch_size) + '_'
|
file_prefix = 'textcnn_bs' + str(config.batch_size) + '_'
|
||||||
|
|
||||||
metric = Accuracy()
|
metric = Accuracy()
|
||||||
metric.clear()
|
metric.clear()
|
||||||
label_list = np.load(args.label_dir, allow_pickle=True)
|
label_list = np.load(config.label_dir, allow_pickle=True)
|
||||||
|
|
||||||
for idx, label in enumerate(label_list):
|
for idx, label in enumerate(label_list):
|
||||||
pred = np.fromfile(os.path.join(args.result_dir, file_prefix + str(idx) + '_0.bin'), np.float32)
|
pred = np.fromfile(os.path.join(config.result_dir, file_prefix + str(idx) + '_0.bin'), np.float32)
|
||||||
pred = pred.reshape(cfg.batch_size, int(pred.shape[0]/cfg.batch_size))
|
pred = pred.reshape(config.batch_size, int(pred.shape[0]/config.batch_size))
|
||||||
metric.update(pred, label)
|
metric.update(pred, label)
|
||||||
accuracy = metric.eval()
|
accuracy = metric.eval()
|
||||||
print("accuracy: ", accuracy)
|
print("accuracy: ", accuracy)
|
||||||
|
|
|
@ -15,37 +15,28 @@
|
||||||
"""
|
"""
|
||||||
##############preprocess textcnn example on movie review#################
|
##############preprocess textcnn example on movie review#################
|
||||||
"""
|
"""
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from src.config import cfg_mr, cfg_subj, cfg_sst2
|
from src.model_utils.config import config
|
||||||
from src.dataset import MovieReview, SST2, Subjectivity
|
from src.dataset import MovieReview, SST2, Subjectivity
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='TextCNN')
|
|
||||||
parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
|
|
||||||
parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
|
|
||||||
args_opt = parser.parse_args()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args_opt.dataset == 'MR':
|
if config.dataset == 'MR':
|
||||||
cfg = cfg_mr
|
instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
|
||||||
instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
|
elif config.dataset == 'SUBJ':
|
||||||
elif args_opt.dataset == 'SUBJ':
|
instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
|
||||||
cfg = cfg_subj
|
elif config.dataset == 'SST2':
|
||||||
instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
|
instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
|
||||||
elif args_opt.dataset == 'SST2':
|
|
||||||
cfg = cfg_sst2
|
|
||||||
instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
|
|
||||||
|
|
||||||
dataset = instance.create_test_dataset(batch_size=cfg.batch_size)
|
dataset = instance.create_test_dataset(batch_size=config.batch_size)
|
||||||
img_path = os.path.join(args_opt.result_path, "00_data")
|
img_path = os.path.join(config.result_path, "00_data")
|
||||||
os.makedirs(img_path)
|
os.makedirs(img_path)
|
||||||
label_list = []
|
label_list = []
|
||||||
for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||||
file_name = "textcnn_bs" + str(cfg.batch_size) + "_" + str(i) + ".bin"
|
file_name = "textcnn_bs" + str(config.batch_size) + "_" + str(i) + ".bin"
|
||||||
file_path = img_path + "/" + file_name
|
file_path = img_path + "/" + file_name
|
||||||
data['data'].tofile(file_path)
|
data['data'].tofile(file_path)
|
||||||
label_list.append(data['label'])
|
label_list.append(data['label'])
|
||||||
|
|
||||||
np.save(args_opt.result_path + "label_ids.npy", label_list)
|
np.save(config.result_path + "label_ids.npy", label_list)
|
||||||
print("="*20, "export bin files finished", "="*20)
|
print("="*20, "export bin files finished", "="*20)
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Device adapter for ModelArts"""
|
||||||
|
|
||||||
|
from src.model_utils.config import config
|
||||||
|
|
||||||
|
if config.enable_modelarts:
|
||||||
|
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
else:
|
||||||
|
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||||
|
]
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Local adapter"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
return "Local Job"
|
|
@ -0,0 +1,115 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Moxing adapter for ModelArts"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import functools
|
||||||
|
from mindspore import context
|
||||||
|
from src.model_utils.config import config
|
||||||
|
|
||||||
|
_global_sync_count = 0
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
job_id = os.getenv('JOB_ID')
|
||||||
|
job_id = job_id if job_id != "" else "default"
|
||||||
|
return job_id
|
||||||
|
|
||||||
|
def sync_data(from_path, to_path):
|
||||||
|
"""
|
||||||
|
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||||
|
Upload data from local directory to remote obs in contrast.
|
||||||
|
"""
|
||||||
|
import moxing as mox
|
||||||
|
import time
|
||||||
|
global _global_sync_count
|
||||||
|
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||||
|
_global_sync_count += 1
|
||||||
|
|
||||||
|
# Each server contains 8 devices as most.
|
||||||
|
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||||
|
print("from path: ", from_path)
|
||||||
|
print("to path: ", to_path)
|
||||||
|
mox.file.copy_parallel(from_path, to_path)
|
||||||
|
print("===finish data synchronization===")
|
||||||
|
try:
|
||||||
|
os.mknod(sync_lock)
|
||||||
|
except IOError:
|
||||||
|
pass
|
||||||
|
print("===save flag===")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if os.path.exists(sync_lock):
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||||
|
|
||||||
|
|
||||||
|
def moxing_wrapper(pre_process=None, post_process=None):
|
||||||
|
"""
|
||||||
|
Moxing wrapper to download dataset and upload outputs.
|
||||||
|
"""
|
||||||
|
def wrapper(run_func):
|
||||||
|
@functools.wraps(run_func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
# Download data from data_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if config.data_url:
|
||||||
|
sync_data(config.data_url, config.data_path)
|
||||||
|
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||||
|
if config.checkpoint_url:
|
||||||
|
sync_data(config.checkpoint_url, config.load_path)
|
||||||
|
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||||
|
if config.train_url:
|
||||||
|
sync_data(config.train_url, config.output_path)
|
||||||
|
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||||
|
|
||||||
|
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||||
|
config.device_num = get_device_num()
|
||||||
|
config.device_id = get_device_id()
|
||||||
|
if not os.path.exists(config.output_path):
|
||||||
|
os.makedirs(config.output_path)
|
||||||
|
|
||||||
|
if pre_process:
|
||||||
|
pre_process()
|
||||||
|
|
||||||
|
run_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Upload data to train_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if post_process:
|
||||||
|
post_process()
|
||||||
|
|
||||||
|
if config.train_url:
|
||||||
|
print("Start to copy output directory")
|
||||||
|
sync_data(config.output_path, config.train_url)
|
||||||
|
return wrapped_func
|
||||||
|
return wrapper
|
|
@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
|
||||||
word_len: 51
|
word_len: 51
|
||||||
vec_length: 40
|
vec_length: 40
|
||||||
base_lr: 5e-3
|
base_lr: 5e-3
|
||||||
|
label_dir: ""
|
||||||
|
result_dir: ""
|
||||||
|
result_path: './preprocess_Result/'
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -25,6 +25,9 @@ checkpoint_file_path: 'train_textcnn-4_149.ckpt'
|
||||||
word_len: 51
|
word_len: 51
|
||||||
vec_length: 40
|
vec_length: 40
|
||||||
base_lr: 8e-4
|
base_lr: 8e-4
|
||||||
|
label_dir: ""
|
||||||
|
result_dir: ""
|
||||||
|
result_path: './preprocess_Result/'
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -26,9 +26,9 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
from utils.moxing_adapter import moxing_wrapper
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from utils.device_adapter import get_device_id, get_rank_id
|
from src.model_utils.device_adapter import get_device_id, get_rank_id
|
||||||
from utils.config import config
|
from src.model_utils.config import config
|
||||||
from src.textcnn import TextCNN
|
from src.textcnn import TextCNN
|
||||||
from src.textcnn import SoftmaxCrossEntropyExpand
|
from src.textcnn import SoftmaxCrossEntropyExpand
|
||||||
from src.dataset import MovieReview, SST2, Subjectivity
|
from src.dataset import MovieReview, SST2, Subjectivity
|
||||||
|
|
Loading…
Reference in New Issue