From 299539406454cae419306ffc4d1c664b0e850bff Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Fri, 7 May 2021 14:39:21 +0800
Subject: [PATCH] add coco convert support for unet

---
 model_zoo/official/cv/unet/README.md          |  54 ++++++-
 model_zoo/official/cv/unet/README_CN.md       |  54 ++++++-
 model_zoo/official/cv/unet/eval.py            |  11 +-
 .../official/cv/unet/preprocess_dataset.py    | 138 ++++++++++++++++++
 .../cv/unet/scripts/run_distribute_train.sh   |  22 ++-
 .../cv/unet/scripts/run_standalone_eval.sh    |  29 +++-
 .../cv/unet/scripts/run_standalone_train.sh   |  29 +++-
 model_zoo/official/cv/unet/src/config.py      |  48 ++++++
 model_zoo/official/cv/unet/src/data_loader.py |  83 ++++++-----
 model_zoo/official/cv/unet/train.py           |  22 +--
 10 files changed, 413 insertions(+), 77 deletions(-)
 create mode 100644 model_zoo/official/cv/unet/preprocess_dataset.py

diff --git a/model_zoo/official/cv/unet/README.md b/model_zoo/official/cv/unet/README.md
index 1dad8f63255..58ccbdfbf5d 100644
--- a/model_zoo/official/cv/unet/README.md
+++ b/model_zoo/official/cv/unet/README.md
@@ -53,7 +53,58 @@ Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
 - Data format:binary files(TIF file)
 - Note:Data will be processed in src/data_loader.py
 
-We also support cell nuclei dataset which is used in [Unet++ original paper](https://arxiv.org/abs/1912.05074). If you want to use the dataset, please add `'dataset': 'Cell_nuclei'` in `src/config.py`.
+We also support a multi-class dataset format, which gets image and mask paths from a tree of directories.
+Each sample folder contains one image and its mask: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
+The directory structure is as follows:
+
+```path
+.
+└─dataset
+  └─0001
+    ├─image.png
+    └─mask.png
+  └─0002
+    ├─image.png
+    └─mask.png
+    ...
+  └─xxxx
+    ├─image.png
+    └─mask.png
+```
+
+When `split` in the config is set to a value in (0, 1), all images are split into a training set and a validation set according to that value; `split` defaults to 0.8.
+If `split` is set to 1.0, you should split the training and validation sets by directory, with the following structure:
+
+```path
+.
+└─dataset
+  └─train
+    └─0001
+      ├─image.png
+      └─mask.png
+      ...
+    └─xxxx
+      ├─image.png
+      └─mask.png
+  └─val
+    └─0001
+      ├─image.png
+      └─mask.png
+      ...
+    └─xxxx
+      ├─image.png
+      └─mask.png
+```
+
+We provide a script to convert the COCO dataset and the Cell_Nuclei dataset used in the [Unet++ original paper](https://arxiv.org/abs/1912.05074) to the multi-class dataset format.
+
+1. Change `cfg_unet` in `src/config.py`; refer to `cfg_unet_nested_cell` and `cfg_unet_simple_coco` in the same file for details.
+
+2. Run the script to convert the dataset to the multi-class format:
+
+```shell
+python preprocess_dataset.py -d /data/save_data_path
+```
 
 ## [Environment Requirements](#contents)
 
@@ -145,6 +196,7 @@ Then you can run everything just like on ascend.
 ├── mindspore_hub_conf.py // hub config file
 ├── postprocess.py // unet 310 infer postprocess.
 ├── preprocess.py // unet 310 infer preprocess dataset
+├── preprocess_dataset.py // script to convert datasets to the multi-class format
 ├── requirements.txt // Requirements of third party package.
``` diff --git a/model_zoo/official/cv/unet/README_CN.md b/model_zoo/official/cv/unet/README_CN.md index 5d5e79d156e..036852b4b2e 100644 --- a/model_zoo/official/cv/unet/README_CN.md +++ b/model_zoo/official/cv/unet/README_CN.md @@ -57,7 +57,58 @@ UNet++是U-Net的增强版本,使用了新的跨层链接方式和深层监督 - 数据格式:二进制文件(TIF) - 注意:数据在src/data_loader.py中处理 -我们也支持一个在 [Unet++](https://arxiv.org/abs/1912.05074) 原论文中使用的数据集 `Cell_nuclei`。可以通过修改`src/config.py`中`'dataset': 'Cell_nuclei'`配置使用. +我们也支持一种 Multi-Class 数据集格式,通过固定的目录结构获取图片和对应标签数据。 +在同一个目录中保存原图片及对应标签,其中图片名为 `"image.png"`,标签名为 `"mask.png"`。 +目录结构如下: + +```path +. +└─dataset + └─0001 + ├─image.png + └─mask.png + └─0002 + ├─image.png + └─mask.png + ... + └─xxxx + ├─image.png + └─mask.png +``` + +通过在`config`中的`split`参数将所有的图片分为训练集和验证集,`split` 默认为 0.8。 +当设置 `split`为 1.0时,通过目录来分训练集和验证集,目录结构如下: + +```path +. +└─dataset + └─train + └─0001 + ├─image.png + └─mask.png + ... + └─xxxx + ├─image.png + └─mask.png + └─val + └─0001 + ├─image.png + └─mask.png + ... + └─xxxx + ├─image.png + └─mask.png +``` + +我们提供了一个脚本来将 COCO 和 Cell_Nuclei 数据集([Unet++ 原论文](https://arxiv.org/abs/1912.05074) 中使用)转换为multi-class格式。 + +1. 在`src/config.py`中修改`cfg_unet`,修改细节请参考`src/config.py`中的`cfg_unet_nested_cell` 和 `cfg_unet_simple_coco`。 + +2. 运行转换脚本: + +```shell +python preprocess_dataset.py -d /data/save_data_path +``` ## 环境要求 @@ -149,6 +200,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR] ├── mindspore_hub_conf.py // hub 配置脚本 ├── postprocess.py // 310 推理后处理脚本 ├── preprocess.py // 310 推理前处理脚本 + ├── preprocess_dataset.py // 适配MultiClass数据集脚本 ├── requirements.txt // 需要的三方库. ``` diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py index d043591ff03..4cdc8ba6977 100644 --- a/model_zoo/official/cv/unet/eval.py +++ b/model_zoo/official/cv/unet/eval.py @@ -19,7 +19,7 @@ import logging from mindspore import context, Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.data_loader import create_dataset, create_cell_nuclei_dataset +from src.data_loader import create_dataset, create_multi_class_dataset from src.unet_medical import UNetMedical from src.unet_nested import NestedUNet, UNet from src.config import cfg_unet @@ -44,9 +44,12 @@ def test_net(data_dir, param_dict = load_checkpoint(ckpt_path) load_param_into_net(net, param_dict) net = UnetEval(net) - if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": - valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, - eval_resize=cfg["eval_resize"], split=0.8) + if 'dataset' in cfg and cfg['dataset'] != "ISBI": + split = cfg['split'] if 'split' in cfg else 0.8 + valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1, + num_classes=cfg['num_classes'], is_train=False, + eval_resize=cfg["eval_resize"], split=split, + python_multiprocessing=False, shuffle=False) else: _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], img_size=cfg['img_size']) diff --git a/model_zoo/official/cv/unet/preprocess_dataset.py b/model_zoo/official/cv/unet/preprocess_dataset.py new file mode 100644 index 00000000000..550ac8e6918 --- /dev/null +++ b/model_zoo/official/cv/unet/preprocess_dataset.py @@ -0,0 +1,138 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Preprocess dataset.
+Each sample folder contains one image and its mask: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
+"""
+import os
+import argparse
+import cv2
+import numpy as np
+from src.config import cfg_unet
+
+def annToMask(ann, height, width):
+    """Convert annotation to RLE and then to binary mask."""
+    from pycocotools import mask as maskHelper
+    segm = ann['segmentation']
+    if isinstance(segm, list):
+        rles = maskHelper.frPyObjects(segm, height, width)
+        rle = maskHelper.merge(rles)
+    elif isinstance(segm['counts'], list):
+        rle = maskHelper.frPyObjects(segm, height, width)
+    else:
+        rle = ann['segmentation']
+    m = maskHelper.decode(rle)
+    return m
+
+def preprocess_cell_nuclei_dataset(param_dict):
+    """
+    Preprocess the Cell Nuclei dataset.
+    Merge all instance masks into one mask, and save it at data_dir/img_id/mask.png.
+    """
+    print("========== start preprocess Cell Nuclei dataset ==========")
+    data_dir = param_dict["data_dir"]
+    img_ids = sorted(next(os.walk(data_dir))[1])
+    for img_id in img_ids:
+        path = os.path.join(data_dir, img_id)
+        if (not os.path.exists(os.path.join(path, "image.png"))) or \
+                (not os.path.exists(os.path.join(path, "mask.png"))):
+            img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
+            if len(img.shape) == 2:
+                img = np.expand_dims(img, axis=-1)
+                img = np.concatenate([img, img, img], axis=-1)
+            mask = []
+            for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
+                mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
+                mask.append(mask_)
+            mask = np.max(mask, axis=0)
+            cv2.imwrite(os.path.join(path, "image.png"), img)
+            cv2.imwrite(os.path.join(path, "mask.png"), mask)
+
+def preprocess_coco_dataset(param_dict):
+    """
+    Preprocess the COCO dataset.
+    Save the image and mask at save_dir/img_name/image.png and save_dir/img_name/mask.png.
+    """
+    print("========== start preprocess coco dataset ==========")
+    from pycocotools.coco import COCO
+    anno_json = param_dict["anno_json"]
+    coco_cls = param_dict["coco_classes"]
+    coco_dir = param_dict["coco_dir"]
+    save_dir = param_dict["save_dir"]
+    coco_cls_dict = {}
+    for i, cls in enumerate(coco_cls):
+        coco_cls_dict[cls] = i
+    coco = COCO(anno_json)
+    classs_dict = {}
+    cat_ids = coco.loadCats(coco.getCatIds())
+    for cat in cat_ids:
+        classs_dict[cat["id"]] = cat["name"]
+    image_ids = coco.getImgIds()
+    images_num = len(image_ids)
+    for ind, img_id in enumerate(image_ids):
+        image_info = coco.loadImgs(img_id)
+        file_name = image_info[0]["file_name"]
+        img_name, _ = os.path.splitext(file_name)
+        image_path = os.path.join(coco_dir, file_name)
+        if not os.path.isfile(image_path):
+            print("{}/{}: {} is in the annotations but does not exist".format(ind + 1, images_num, image_path))
+            continue
+        if not os.path.exists(os.path.join(save_dir, img_name)):
+            os.makedirs(os.path.join(save_dir, img_name))
+        anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
+        anno = coco.loadAnns(anno_ids)
+        h = coco.imgs[img_id]["height"]
+        w = coco.imgs[img_id]["width"]
+        mask = np.zeros((h, w), dtype=np.uint8)
+        for instance in anno:
+            m = annToMask(instance, h, w)
+            c = coco_cls_dict[classs_dict[instance["category_id"]]]
+            if len(m.shape) < 3:
+                mask[:, :] += (mask == 0) * (m * c)
+            else:
+                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
+        img = cv2.imread(image_path)
+        cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img)
+        cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask)
+
+def preprocess_dataset(cfg, data_dir):
+    """Select preprocess function."""
+    if cfg['dataset'].lower() == "cell_nuclei":
+        preprocess_cell_nuclei_dataset({"data_dir": data_dir})
+    elif cfg['dataset'].lower() == "coco":
+        if 'split' in cfg and cfg['split'] == 1.0:
+            train_data_path = os.path.join(data_dir, "train")
+            val_data_path = os.path.join(data_dir, "val")
+            train_param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
+                                "coco_dir": cfg["coco_dir"], "save_dir": train_data_path}
+            preprocess_coco_dataset(train_param_dict)
+            val_param_dict = {"anno_json": cfg["val_anno_json"], "coco_classes": cfg["coco_classes"],
+                              "coco_dir": cfg["val_coco_dir"], "save_dir": val_data_path}
+            preprocess_coco_dataset(val_param_dict)
+        else:
+            param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
+                          "coco_dir": cfg["coco_dir"], "save_dir": data_dir}
+            preprocess_coco_dataset(param_dict)
+    else:
+        raise ValueError("Unsupported dataset mode: {}".format(cfg['dataset']))
+    print("========== end preprocess dataset ==========")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Convert COCO or Cell_Nuclei data to the multi-class dataset format',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
+                        help='directory to save the converted dataset')
+    args = parser.parse_args()
+    preprocess_dataset(cfg_unet, args.data_url)
diff --git a/model_zoo/official/cv/unet/scripts/run_distribute_train.sh b/model_zoo/official/cv/unet/scripts/run_distribute_train.sh
index 5d1555d9e99..e7749935065 100644
--- a/model_zoo/official/cv/unet/scripts/run_distribute_train.sh
+++ b/model_zoo/official/cv/unet/scripts/run_distribute_train.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 
2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,14 @@ # limitations under the License. # ============================================================================ +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + if [ $# != 2 ] then echo "==============================================================================================================" @@ -24,10 +32,11 @@ then echo "==============================================================================================================" exit 1 fi - +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) export HCCL_CONNECT_TIMEOUT=600 export RANK_SIZE=8 - +DATASET=$(get_real_path $2) +export RANK_TABLE_FILE=$(get_real_path $1) for((i=0;i env.log - python3 train.py \ + python3 ${PROJECT_DIR}/../train.py \ --run_distribute=True \ - --data_url=$2 > log.txt 2>&1 & + --data_url=$DATASET > log.txt 2>&1 & cd ../ -done \ No newline at end of file +done diff --git a/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh b/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh index 2965138236a..c91e66c3c37 100644 --- a/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh +++ b/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +14,31 @@ # limitations under the License. # ============================================================================ -if [ $# != 2 ] +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +if [ $# != 2 ] && [ $# != 3 ] then echo "==============================================================================================================" echo "Please run the script as: " - echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]" - echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/" + echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [DEVICE_ID](option, default is 0)" + echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ 0" echo "==============================================================================================================" + exit 1 fi - +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) export DEVICE_ID=0 -python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 & \ No newline at end of file +if [ $# != 2 ] +then + export DEVICE_ID=$3 +fi +DATASET=$(get_real_path $1) +CHECKPOINT=$(get_real_path $2) +echo "========== start run evaluation ===========" +echo "please get log at eval.log" +python ${PROJECT_DIR}/../eval.py --data_url=$DATASET --ckpt_path=$CHECKPOINT > eval.log 2>&1 & diff --git a/model_zoo/official/cv/unet/scripts/run_standalone_train.sh b/model_zoo/official/cv/unet/scripts/run_standalone_train.sh index c5f24633bc1..a7fad97c09e 100644 --- a/model_zoo/official/cv/unet/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/unet/scripts/run_standalone_train.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not 
use this file except in compliance with the License. @@ -14,14 +14,31 @@ # limitations under the License. # ============================================================================ -if [ $# != 1 ] +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +if [ $# != 1 ] && [ $# != 2 ] then echo "==============================================================================================================" echo "Please run the script as: " - echo "bash scripts/run_standalone_train.sh [DATASET]" - echo "for example: bash run_standalone_train.sh /path/to/data/" + echo "bash scripts/run_standalone_train.sh [DATASET] [DEVICE_ID](option, default is 0)" + echo "for example: bash run_standalone_train.sh /path/to/data/ 0" echo "==============================================================================================================" + exit 1 +fi +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export DEVICE_ID=0 +if [ $# != 1 ] +then + export DEVICE_ID=$2 fi -export DEVICE_ID=0 -python train.py --data_url=$1 > train.log 2>&1 & \ No newline at end of file +DATASET=$(get_real_path $1) +echo "========== start run training ===========" +echo "please get log at train.log" +python ${PROJECT_DIR}/../train.py --data_url=$DATASET > train.log 2>&1 & diff --git a/model_zoo/official/cv/unet/src/config.py b/model_zoo/official/cv/unet/src/config.py index 82453111408..a40cc331989 100644 --- a/model_zoo/official/cv/unet/src/config.py +++ b/model_zoo/official/cv/unet/src/config.py @@ -124,6 +124,54 @@ cfg_unet_simple = { 'eval_resize': False } +cfg_unet_simple_coco = { + 'model': 'unet_simple', + 'dataset': 'COCO', + 'split': 1.0, + 'img_size': [512, 512], + 'lr': 3e-4, + 'epochs': 80, + 'repeat': 1, + 'distribute_epochs': 120, + 'cross_valid_ind': 1, + 'batchsize': 16, + 'num_channels': 3, + + 'keep_checkpoint_max': 10, + 'weight_decay': 0.0005, + 'loss_scale': 1024.0, + 'FixedLossScaleManager': 1024.0, + + 'resume': False, + 'resume_ckpt': './', + 'transfer_training': False, + 'filter_weight': ["final.weight"], + 'eval_activate': 'Softmax', + 'eval_resize': False, + + 'num_classes': 81, + 'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + # change the following settings to real path + 'anno_json': '/data/coco2017/annotations/instances_train2017.json', + 'val_anno_json': '/data/coco2017/annotations/instances_val2017.json', + 'coco_dir': '/data/coco2017/train2017', + 'val_coco_dir': '/data/coco2017/val2017' +} + cfg_unet = cfg_unet_simple if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']: print("ISBI dataset not support resize to original image 
size when in evaluation.") diff --git a/model_zoo/official/cv/unet/src/data_loader.py b/model_zoo/official/cv/unet/src/data_loader.py index 3dee1493b74..82921647376 100644 --- a/model_zoo/official/cv/unet/src/data_loader.py +++ b/model_zoo/official/cv/unet/src/data_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -165,38 +165,33 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro return train_ds, valid_ds -class CellNucleiDataset: +class MultiClassDataset: """ - Cell nuclei dataset preprocess class. + Read image and mask from original images, and split all data into train_dataset and val_dataset by `split`. + Get image path and mask path from a tree of directories, + images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`. """ - def __init__(self, data_dir, repeat, is_train=False, split=0.8): + def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False): self.data_dir = data_dir - self.img_ids = sorted(next(os.walk(self.data_dir))[1]) - self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat - np.random.shuffle(self.train_ids) - self.val_ids = self.img_ids[int(len(self.img_ids) * split):] self.is_train = is_train - self._preprocess_dataset() - - def _preprocess_dataset(self): - for img_id in self.img_ids: - path = os.path.join(self.data_dir, img_id) - if (not os.path.exists(os.path.join(path, "image.png"))) or \ - (not os.path.exists(os.path.join(path, "mask.png"))): - img = cv2.imread(os.path.join(path, "images", img_id + ".png")) - if len(img.shape) == 2: - img = np.expand_dims(img, axis=-1) - img = np.concatenate([img, img, img], axis=-1) - mask = [] - for mask_file in next(os.walk(os.path.join(path, "masks")))[2]: - mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE) - mask.append(mask_) - mask = np.max(mask, axis=0) - cv2.imwrite(os.path.join(path, "image.png"), img) - cv2.imwrite(os.path.join(path, "mask.png"), mask) + self.split = (split != 1.0) + if self.split: + self.img_ids = sorted(next(os.walk(self.data_dir))[1]) + self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat + self.val_ids = self.img_ids[int(len(self.img_ids) * split):] + else: + self.train_ids = sorted(next(os.walk(os.path.join(self.data_dir, "train")))[1]) + self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1]) + if shuffle: + np.random.shuffle(self.train_ids) def _read_img_mask(self, img_id): - path = os.path.join(self.data_dir, img_id) + if self.split: + path = os.path.join(self.data_dir, img_id) + elif self.is_train: + path = os.path.join(self.data_dir, "train", img_id) + else: + path = os.path.join(self.data_dir, "val", img_id) img = cv2.imread(os.path.join(path, "image.png")) mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE) return img, mask @@ -216,9 +211,9 @@ class CellNucleiDataset: return len(self.train_ids) return len(self.val_ids) -def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False): +def preprocess_img_mask(img, mask, num_classes, img_size, augment=False, eval_resize=False): """ - Preprocess for cell nuclei dataset. + Preprocess for multi-class dataset. Random crop and flip images and masks when augment is True. 
""" if augment: @@ -240,24 +235,28 @@ def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False): mask = cv2.resize(mask, img_size) img = (img.astype(np.float32) - 127.5) / 127.5 img = img.transpose(2, 0, 1) - mask = mask.astype(np.float32) / 255 - mask = (mask > 0.5).astype(np.int) - mask = (np.arange(2) == mask[..., None]).astype(int) + if num_classes == 2: + mask = mask.astype(np.float32) / mask.max() + mask = (mask > 0.5).astype(np.int) + else: + mask = mask.astype(np.int) + mask = (np.arange(num_classes) == mask[..., None]).astype(int) mask = mask.transpose(2, 0, 1).astype(np.float32) return img, mask -def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False, - split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8): +def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_classes=2, is_train=False, augment=False, + eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True, + num_parallel_workers=8, shuffle=True): """ - Get generator dataset for cell nuclei dataset. + Get generator dataset for multi-class dataset. """ - cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split) - sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train) - dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler) - compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train, - eval_resize)) - dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names, - output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names, + mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle) + sampler = ds.DistributedSampler(group_size, rank, shuffle=shuffle) + dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, sampler=sampler) + compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size), + augment and is_train, eval_resize)) + dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names, + output_columns=mc_dataset.column_names, column_order=mc_dataset.column_names, python_multiprocessing=python_multiprocessing, num_parallel_workers=num_parallel_workers) dataset = dataset.batch(batch_size, drop_remainder=is_train) diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py index f97e492d7b3..e0bb07f0ed7 100644 --- a/model_zoo/official/cv/unet/train.py +++ b/model_zoo/official/cv/unet/train.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -28,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.unet_medical import UNetMedical from src.unet_nested import NestedUNet, UNet -from src.data_loader import create_dataset, create_cell_nuclei_dataset +from src.data_loader import create_dataset, create_multi_class_dataset from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff from src.config import cfg_unet @@ -79,16 +79,18 @@ def train_net(args_opt, criterion = MultiCrossEntropyWithLogits() else: criterion = CrossEntropyWithLogits() - if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": - repeat = cfg['repeat'] + if 'dataset' in cfg and cfg['dataset'] != "ISBI": + repeat = cfg['repeat'] if 'repeat' in cfg else 1 + split = cfg['split'] if 'split' in cfg else 0.8 dataset_sink_mode = True per_print_times = 0 - train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size, - is_train=True, augment=True, split=0.8, rank=rank, - group_size=group_size) - valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, - eval_resize=cfg["eval_resize"], split=0.8, - python_multiprocessing=False) + train_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], repeat, batch_size, + num_classes=cfg['num_classes'], is_train=True, augment=True, + split=split, rank=rank, group_size=group_size, shuffle=True) + valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1, + num_classes=cfg['num_classes'], is_train=False, + eval_resize=cfg["eval_resize"], split=split, + python_multiprocessing=False, shuffle=False) else: repeat = cfg['repeat'] dataset_sink_mode = False
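For reference, the multi-class mask encoding that `preprocess_img_mask` applies can be sketched as a standalone function. This is a minimal sketch, assuming a grayscale mask whose pixel values are class indices (as written by `preprocess_dataset.py`); the guard against an all-zero mask in the binary branch is added here for safety and is not part of the patch:

```python
import numpy as np

def one_hot_mask(mask, num_classes):
    """Encode a (H, W) mask of class indices as a (num_classes, H, W) float32 array."""
    if num_classes == 2:
        # Binary case: normalize to [0, 1] and threshold, as in the patch.
        # max(..., 1) only guards against an all-zero mask (assumption, not in the patch).
        mask = mask.astype(np.float32) / max(int(mask.max()), 1)
        mask = (mask > 0.5).astype(np.int64)
    else:
        # Multi-class case: pixel values are already class indices.
        mask = mask.astype(np.int64)
    onehot = (np.arange(num_classes) == mask[..., None]).astype(np.float32)
    return onehot.transpose(2, 0, 1)

# A 2x2 mask containing classes 0 and 3, encoded for num_classes=4 -> shape (4, 2, 2).
print(one_hot_mask(np.array([[0, 3], [3, 0]], dtype=np.uint8), 4).shape)
```

The network therefore receives one channel per class, which is why `num_classes` is threaded through `create_multi_class_dataset` and set to 81 in `cfg_unet_simple_coco`.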