Add COCO conversion support for UNet
commit 2995394064 (parent 7a3d9f2ad7)
@@ -53,7 +53,58 @@ Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)

- Data format: binary files (TIF file)
- Note: data will be processed in src/data_loader.py

We also support the cell nuclei dataset used in the [Unet++ original paper](https://arxiv.org/abs/1912.05074). To use it, set `'dataset': 'Cell_nuclei'` in `src/config.py`.
We also support a Multi-Class dataset format that reads image and mask paths from a tree of directories.
Each folder holds one sample: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
The directory structure is as follows:

```path
.
└─dataset
  └─0001
    ├─image.png
    └─mask.png
  └─0002
    ├─image.png
    └─mask.png
  ...
  └─xxxx
    ├─image.png
    └─mask.png
```

When `split` in the config is set to a value in (0, 1), all images are split into a training set and a validation set according to that value; the default `split` is 0.8 (a short sketch of the split logic follows the directory layout below).
If `split` is set to 1.0, you should split the training and validation sets by directories, with the following structure:

```path
.
└─dataset
  └─train
    └─0001
      ├─image.png
      └─mask.png
    ...
    └─xxxx
      ├─image.png
      └─mask.png
  └─val
    └─0001
      ├─image.png
      └─mask.png
    ...
    └─xxxx
      ├─image.png
      └─mask.png
```

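The ratio-based split is a simple cut over the sorted sample folders. A minimal sketch of the idea, mirroring the logic in `src/data_loader.py` (illustrative, not the exact implementation):

```python
import os

def split_ids(data_dir, split=0.8):
    """Return (train_ids, val_ids) by cutting the sorted sample folders at `split`."""
    img_ids = sorted(next(os.walk(data_dir))[1])  # e.g. ['0001', '0002', ...]
    cut = int(len(img_ids) * split)
    return img_ids[:cut], img_ids[cut:]
```
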
We provide a script to convert the COCO dataset and the Cell_Nuclei dataset used in the [Unet++ original paper](https://arxiv.org/abs/1912.05074) to the multi-class dataset format.

1. Change `cfg_unet` in `src/config.py`; refer to `cfg_unet_nested_cell` and `cfg_unet_simple_coco` in `src/config.py` for details (an illustrative snippet follows the command below).

2. Run the script to convert to the multi-class dataset format:

```shell
python preprocess_dataset.py -d /data/save_data_path
```

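As an example of step 1, a COCO conversion could be selected in `src/config.py` roughly as below; this is a sketch, and the dataset paths are placeholders that must point to your local COCO copy:

```python
# src/config.py (illustrative sketch)
cfg_unet = cfg_unet_simple_coco
cfg_unet['anno_json'] = '/data/coco2017/annotations/instances_train2017.json'
cfg_unet['val_anno_json'] = '/data/coco2017/annotations/instances_val2017.json'
cfg_unet['coco_dir'] = '/data/coco2017/train2017'
cfg_unet['val_coco_dir'] = '/data/coco2017/val2017'
```
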
## [Environment Requirements](#contents)

@@ -145,6 +196,7 @@ Then you can run everything just like on ascend.
  ├── mindspore_hub_conf.py        // hub config file
  ├── postprocess.py               // unet 310 inference postprocess script
  ├── preprocess.py                // unet 310 inference preprocess script
  ├── preprocess_dataset.py        // script to adapt datasets to the multi-class format
  ├── requirements.txt             // third-party package requirements
```

@@ -57,7 +57,58 @@ UNet++ is an enhanced version of U-Net, using new cross-layer connections and deep supervision

- Data format: binary files (TIF)
- Note: data is processed in src/data_loader.py

We also support the `Cell_nuclei` dataset used in the [Unet++](https://arxiv.org/abs/1912.05074) original paper. It can be enabled by setting `'dataset': 'Cell_nuclei'` in `src/config.py`.
We also support a Multi-Class dataset format that obtains images and the corresponding labels from a fixed directory structure.
Each directory holds one image and its label: the image is named `"image.png"` and the label is named `"mask.png"`.
The directory structure is as follows:

```path
.
└─dataset
  └─0001
    ├─image.png
    └─mask.png
  └─0002
    ├─image.png
    └─mask.png
  ...
  └─xxxx
    ├─image.png
    └─mask.png
```

All images are split into a training set and a validation set according to the `split` parameter in the config; `split` defaults to 0.8.
When `split` is set to 1.0, the training and validation sets are separated by directories, with the following structure:

```path
.
└─dataset
  └─train
    └─0001
      ├─image.png
      └─mask.png
    ...
    └─xxxx
      ├─image.png
      └─mask.png
  └─val
    └─0001
      ├─image.png
      └─mask.png
    ...
    └─xxxx
      ├─image.png
      └─mask.png
```

We provide a script to convert the COCO dataset and the Cell_Nuclei dataset (used in the [Unet++ original paper](https://arxiv.org/abs/1912.05074)) to the multi-class format.

1. Modify `cfg_unet` in `src/config.py`; refer to `cfg_unet_nested_cell` and `cfg_unet_simple_coco` in `src/config.py` for details.

2. Run the conversion script:

```shell
python preprocess_dataset.py -d /data/save_data_path
```

## Environment Requirements

@@ -149,6 +200,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
  ├── mindspore_hub_conf.py        // hub config script
  ├── postprocess.py               // 310 inference postprocess script
  ├── preprocess.py                // 310 inference preprocess script
  ├── preprocess_dataset.py        // script to adapt the MultiClass dataset
  ├── requirements.txt             // required third-party packages
```

@@ -19,7 +19,7 @@ import logging
from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.data_loader import create_dataset, create_multi_class_dataset
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet
@@ -44,9 +44,12 @@ def test_net(data_dir,
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net = UnetEval(net)
    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
                                                   eval_resize=cfg["eval_resize"], split=0.8)
    if 'dataset' in cfg and cfg['dataset'] != "ISBI":
        split = cfg['split'] if 'split' in cfg else 0.8
        valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
                                                   num_classes=cfg['num_classes'], is_train=False,
                                                   eval_resize=cfg["eval_resize"], split=split,
                                                   python_multiprocessing=False, shuffle=False)
    else:
        _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
                                          do_crop=cfg['crop'], img_size=cfg['img_size'])

@@ -0,0 +1,138 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Preprocess dataset.
Each sample folder holds one image: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
"""
import os
import argparse
import cv2
import numpy as np
from model_zoo.official.cv.unet.src.config import cfg_unet

def annToMask(ann, height, width):
    """Convert annotation to RLE and then to binary mask."""
    from pycocotools import mask as maskHelper
    segm = ann['segmentation']
    if isinstance(segm, list):
        # polygon segmentation: convert each polygon to RLE and merge them
        rles = maskHelper.frPyObjects(segm, height, width)
        rle = maskHelper.merge(rles)
    elif isinstance(segm['counts'], list):
        # uncompressed RLE
        rle = maskHelper.frPyObjects(segm, height, width)
    else:
        # already compressed RLE
        rle = ann['segmentation']
    m = maskHelper.decode(rle)
    return m
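# Illustrative use of annToMask (assumes a loaded pycocotools COCO object `coco` and an image id `img_id`):
#   ann = coco.loadAnns(coco.getAnnIds(imgIds=img_id))[0]
#   binary_mask = annToMask(ann, coco.imgs[img_id]["height"], coco.imgs[img_id]["width"])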

def preprocess_cell_nuclei_dataset(param_dict):
    """
    Preprocess for Cell Nuclei dataset.
    Merge all instance masks into a single mask and save it at data_dir/img_id/mask.png.
    """
    print("========== start preprocess Cell Nuclei dataset ==========")
    data_dir = param_dict["data_dir"]
    img_ids = sorted(next(os.walk(data_dir))[1])
    for img_id in img_ids:
        path = os.path.join(data_dir, img_id)
        if (not os.path.exists(os.path.join(path, "image.png"))) or \
                (not os.path.exists(os.path.join(path, "mask.png"))):
            img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
            if len(img.shape) == 2:
                # grayscale image: replicate the single channel to get a 3-channel image
                img = np.expand_dims(img, axis=-1)
                img = np.concatenate([img, img, img], axis=-1)
            mask = []
            for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
                mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
                mask.append(mask_)
            # merge all instance masks by taking the per-pixel maximum
            mask = np.max(mask, axis=0)
            cv2.imwrite(os.path.join(path, "image.png"), img)
            cv2.imwrite(os.path.join(path, "mask.png"), mask)

def preprocess_coco_dataset(param_dict):
    """
    Preprocess for COCO dataset.
    Save the image and mask at save_dir/img_name/image.png and save_dir/img_name/mask.png.
    """
    print("========== start preprocess coco dataset ==========")
    from pycocotools.coco import COCO
    anno_json = param_dict["anno_json"]
    coco_cls = param_dict["coco_classes"]
    coco_dir = param_dict["coco_dir"]
    save_dir = param_dict["save_dir"]
    # map class name -> index used in the multi-class mask
    coco_cls_dict = {}
    for i, cls in enumerate(coco_cls):
        coco_cls_dict[cls] = i
    coco = COCO(anno_json)
    # map COCO category id -> category name
    classs_dict = {}
    cat_ids = coco.loadCats(coco.getCatIds())
    for cat in cat_ids:
        classs_dict[cat["id"]] = cat["name"]
    image_ids = coco.getImgIds()
    images_num = len(image_ids)
    for ind, img_id in enumerate(image_ids):
        image_info = coco.loadImgs(img_id)
        file_name = image_info[0]["file_name"]
        img_name, _ = os.path.splitext(file_name)
        image_path = os.path.join(coco_dir, file_name)
        if not os.path.isfile(image_path):
            print("{}/{}: {} is in annotations but does not exist".format(ind + 1, images_num, image_path))
            continue
        if not os.path.exists(os.path.join(save_dir, img_name)):
            os.makedirs(os.path.join(save_dir, img_name))
        anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = coco.loadAnns(anno_ids)
        h = coco.imgs[img_id]["height"]
        w = coco.imgs[img_id]["width"]
        mask = np.zeros((h, w), dtype=np.uint8)
        for instance in anno:
            m = annToMask(instance, h, w)
            c = coco_cls_dict[classs_dict[instance["category_id"]]]
            # only fill pixels that are still background (0), so earlier instances are not overwritten
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        img = cv2.imread(image_path)
        cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img)
        cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask)

def preprocess_dataset(cfg, data_dir):
    """Select the preprocess function according to the configured dataset."""
    if cfg['dataset'].lower() == "cell_nuclei":
        preprocess_cell_nuclei_dataset({"data_dir": data_dir})
    elif cfg['dataset'].lower() == "coco":
        if 'split' in cfg and cfg['split'] == 1.0:
            # split == 1.0: convert train and val into separate sub-directories
            train_data_path = os.path.join(data_dir, "train")
            val_data_path = os.path.join(data_dir, "val")
            train_param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
                                "coco_dir": cfg["coco_dir"], "save_dir": train_data_path}
            preprocess_coco_dataset(train_param_dict)
            val_param_dict = {"anno_json": cfg["val_anno_json"], "coco_classes": cfg["coco_classes"],
                              "coco_dir": cfg["val_coco_dir"], "save_dir": val_data_path}
            preprocess_coco_dataset(val_param_dict)
        else:
            param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
                          "coco_dir": cfg["coco_dir"], "save_dir": data_dir}
            preprocess_coco_dataset(param_dict)
    else:
        raise ValueError("Unsupported dataset mode {}".format(cfg['dataset']))
    print("========== end preprocess dataset ==========")
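
# Example invocation (cfg_unet must be set to a COCO or Cell_nuclei config in src/config.py):
#   python preprocess_dataset.py -d /data/save_data_path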
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert datasets to the multi-class format for UNet',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
                        help='directory to save the converted data')
    args = parser.parse_args()
    preprocess_dataset(cfg_unet, args.data_url)

@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,14 @@
# limitations under the License.
# ============================================================================

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

if [ $# != 2 ]
then
    echo "=============================================================================================================="
@@ -24,10 +32,11 @@ then
    echo "=============================================================================================================="
    exit 1
fi

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export HCCL_CONNECT_TIMEOUT=600
export RANK_SIZE=8

DATASET=$(get_real_path $2)
export RANK_TABLE_FILE=$(get_real_path $1)
for((i=0;i<RANK_SIZE;i++))
do
    rm -rf LOG$i
@@ -35,16 +44,15 @@ do
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cd ./LOG$i || exit
    export RANK_TABLE_FILE=$1
    export RANK_SIZE=8
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log

    python3 train.py \
    python3 ${PROJECT_DIR}/../train.py \
    --run_distribute=True \
    --data_url=$2 > log.txt 2>&1 &
    --data_url=$DATASET > log.txt 2>&1 &

    cd ../
done
done

@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,31 @@
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

if [ $# != 2 ] && [ $# != 3 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
    echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
    echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [DEVICE_ID](optional, default is 0)"
    echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ 0"
    echo "=============================================================================================================="
    exit 1
fi

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 &
if [ $# != 2 ]
then
    export DEVICE_ID=$3
fi
DATASET=$(get_real_path $1)
CHECKPOINT=$(get_real_path $2)
echo "========== start run evaluation ==========="
echo "the evaluation log is saved to eval.log"
python ${PROJECT_DIR}/../eval.py --data_url=$DATASET --ckpt_path=$CHECKPOINT > eval.log 2>&1 &

@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,31 @@
# limitations under the License.
# ============================================================================

if [ $# != 1 ]
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

if [ $# != 1 ] && [ $# != 2 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_standalone_train.sh [DATASET]"
    echo "for example: bash run_standalone_train.sh /path/to/data/"
    echo "bash scripts/run_standalone_train.sh [DATASET] [DEVICE_ID](optional, default is 0)"
    echo "for example: bash run_standalone_train.sh /path/to/data/ 0"
    echo "=============================================================================================================="
    exit 1
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
if [ $# != 1 ]
then
    export DEVICE_ID=$2
fi

export DEVICE_ID=0
python train.py --data_url=$1 > train.log 2>&1 &
DATASET=$(get_real_path $1)
echo "========== start run training ==========="
echo "the training log is saved to train.log"
python ${PROJECT_DIR}/../train.py --data_url=$DATASET > train.log 2>&1 &

@@ -124,6 +124,54 @@ cfg_unet_simple = {
    'eval_resize': False
}

cfg_unet_simple_coco = {
    'model': 'unet_simple',
    'dataset': 'COCO',
    'split': 1.0,
    'img_size': [512, 512],
    'lr': 3e-4,
    'epochs': 80,
    'repeat': 1,
    'distribute_epochs': 120,
    'cross_valid_ind': 1,
    'batchsize': 16,
    'num_channels': 3,

    'keep_checkpoint_max': 10,
    'weight_decay': 0.0005,
    'loss_scale': 1024.0,
    'FixedLossScaleManager': 1024.0,

    'resume': False,
    'resume_ckpt': './',
    'transfer_training': False,
    'filter_weight': ["final.weight"],
    'eval_activate': 'Softmax',
    'eval_resize': False,

    'num_classes': 81,
    'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                     'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                     'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                     'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                     'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                     'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                     'kite', 'baseball bat', 'baseball glove', 'skateboard',
                     'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                     'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                     'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                     'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                     'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                     'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                     'refrigerator', 'book', 'clock', 'vase', 'scissors',
                     'teddy bear', 'hair drier', 'toothbrush'),
    # change the following settings to real paths
    'anno_json': '/data/coco2017/annotations/instances_train2017.json',
    'val_anno_json': '/data/coco2017/annotations/instances_val2017.json',
    'coco_dir': '/data/coco2017/train2017',
    'val_coco_dir': '/data/coco2017/val2017'
}

cfg_unet = cfg_unet_simple
if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']:
    print("ISBI dataset does not support resizing to the original image size during evaluation.")

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -165,38 +165,33 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro

    return train_ds, valid_ds

class CellNucleiDataset:
class MultiClassDataset:
    """
    Cell nuclei dataset preprocess class.
    Read image and mask from original images, and split all data into train_dataset and val_dataset by `split`.
    Get image path and mask path from a tree of directories;
    each folder holds one sample: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
    """
    def __init__(self, data_dir, repeat, is_train=False, split=0.8):
    def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False):
        self.data_dir = data_dir
        self.img_ids = sorted(next(os.walk(self.data_dir))[1])
        self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
        np.random.shuffle(self.train_ids)
        self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
        self.is_train = is_train
        self._preprocess_dataset()

    def _preprocess_dataset(self):
        for img_id in self.img_ids:
            path = os.path.join(self.data_dir, img_id)
            if (not os.path.exists(os.path.join(path, "image.png"))) or \
                    (not os.path.exists(os.path.join(path, "mask.png"))):
                img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
                if len(img.shape) == 2:
                    img = np.expand_dims(img, axis=-1)
                    img = np.concatenate([img, img, img], axis=-1)
                mask = []
                for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
                    mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
                    mask.append(mask_)
                mask = np.max(mask, axis=0)
                cv2.imwrite(os.path.join(path, "image.png"), img)
                cv2.imwrite(os.path.join(path, "mask.png"), mask)
        self.split = (split != 1.0)
        if self.split:
            # split by ratio over the sorted sample folders
            self.img_ids = sorted(next(os.walk(self.data_dir))[1])
            self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
            self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
        else:
            # split == 1.0: train and val samples live in separate sub-directories
            self.train_ids = sorted(next(os.walk(os.path.join(self.data_dir, "train")))[1])
            self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1])
        if shuffle:
            np.random.shuffle(self.train_ids)

    def _read_img_mask(self, img_id):
        path = os.path.join(self.data_dir, img_id)
        if self.split:
            path = os.path.join(self.data_dir, img_id)
        elif self.is_train:
            path = os.path.join(self.data_dir, "train", img_id)
        else:
            path = os.path.join(self.data_dir, "val", img_id)
        img = cv2.imread(os.path.join(path, "image.png"))
        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
        return img, mask

@@ -216,9 +211,9 @@ class CellNucleiDataset:
            return len(self.train_ids)
        return len(self.val_ids)

def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False):
def preprocess_img_mask(img, mask, num_classes, img_size, augment=False, eval_resize=False):
    """
    Preprocess for cell nuclei dataset.
    Preprocess for multi-class dataset.
    Random crop and flip images and masks when augment is True.
    """
    if augment:
@@ -240,24 +235,28 @@ def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False):
        mask = cv2.resize(mask, img_size)
    img = (img.astype(np.float32) - 127.5) / 127.5
    img = img.transpose(2, 0, 1)
    mask = mask.astype(np.float32) / 255
    mask = (mask > 0.5).astype(np.int)
    mask = (np.arange(2) == mask[..., None]).astype(int)
    if num_classes == 2:
        mask = mask.astype(np.float32) / mask.max()
        mask = (mask > 0.5).astype(np.int)
    else:
        mask = mask.astype(np.int)
    # one-hot encode the mask: (H, W) -> (H, W, num_classes)
    mask = (np.arange(num_classes) == mask[..., None]).astype(int)
    mask = mask.transpose(2, 0, 1).astype(np.float32)
    return img, mask

def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False,
                               split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_classes=2, is_train=False, augment=False,
                               eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True,
                               num_parallel_workers=8, shuffle=True):
    """
    Get generator dataset for cell nuclei dataset.
    Get generator dataset for multi-class dataset.
    """
    cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
    sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
    dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train,
                                                                eval_resize))
    dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
                          output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
    mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle)
    sampler = ds.DistributedSampler(group_size, rank, shuffle=shuffle)
    dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, sampler=sampler)
    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size),
                                                                augment and is_train, eval_resize))
    dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names,
                          output_columns=mc_dataset.column_names, column_order=mc_dataset.column_names,
                          python_multiprocessing=python_multiprocessing,
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.batch(batch_size, drop_remainder=is_train)

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.data_loader import create_dataset, create_multi_class_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
from src.config import cfg_unet

@@ -79,16 +79,18 @@ def train_net(args_opt,
        criterion = MultiCrossEntropyWithLogits()
    else:
        criterion = CrossEntropyWithLogits()
    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        repeat = cfg['repeat']
    if 'dataset' in cfg and cfg['dataset'] != "ISBI":
        repeat = cfg['repeat'] if 'repeat' in cfg else 1
        split = cfg['split'] if 'split' in cfg else 0.8
        dataset_sink_mode = True
        per_print_times = 0
        train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
                                                   is_train=True, augment=True, split=0.8, rank=rank,
                                                   group_size=group_size)
        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
                                                   eval_resize=cfg["eval_resize"], split=0.8,
                                                   python_multiprocessing=False)
        train_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], repeat, batch_size,
                                                   num_classes=cfg['num_classes'], is_train=True, augment=True,
                                                   split=split, rank=rank, group_size=group_size, shuffle=True)
        valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
                                                   num_classes=cfg['num_classes'], is_train=False,
                                                   eval_resize=cfg["eval_resize"], split=split,
                                                   python_multiprocessing=False, shuffle=False)
    else:
        repeat = cfg['repeat']
        dataset_sink_mode = False