add coco convert support for unet

zhaoting 2021-05-07 14:39:21 +08:00
parent 7a3d9f2ad7
commit 2995394064
10 changed files with 413 additions and 77 deletions

View File

@ -53,7 +53,58 @@ Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
- Data format: binary files (TIF file)
- Note: Data will be processed in src/data_loader.py
We also support cell nuclei dataset which is used in [Unet++ original paper](https://arxiv.org/abs/1912.05074). If you want to use the dataset, please add `'dataset': 'Cell_nuclei'` in `src/config.py`.
We also support a Multi-Class dataset format that reads image and mask paths from a tree of directories.
Each folder contains one sample: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
The directory structure is as follows:
```path
.
└─dataset
└─0001
├─image.png
└─mask.png
└─0002
├─image.png
└─mask.png
...
└─xxxx
├─image.png
└─mask.png
```
When `split` in the config is set to a value in (0, 1), all images are split into a training set and a validation set by that ratio; the default `split` is 0.8. For example, with 100 sample folders and `split`=0.8, the first 80 folders (sorted by name) form the training set and the remaining 20 the validation set.
If `split` is set to 1.0, you should split the training and validation sets by directories, with the following structure:
```path
.
└─dataset
└─train
└─0001
├─image.png
└─mask.png
...
└─xxxx
├─image.png
└─mask.png
└─val
└─0001
├─image.png
└─mask.png
...
└─xxxx
├─image.png
└─mask.png
```
We provide a script to convert the COCO dataset and the Cell_Nuclei dataset used in the [Unet++ original paper](https://arxiv.org/abs/1912.05074) to the multi-class dataset format.
1. Set `cfg_unet` in `src/config.py`; refer to `cfg_unet_nested_cell` and `cfg_unet_simple_coco` in `src/config.py` for details (a sketch follows the command below).
2. Run the script to convert the data to the multi-class dataset format:
```shell
python preprocess_dataset.py -d /data/save_data_path
```
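For example, to convert COCO you would make `cfg_unet_simple_coco` the active configuration before running the script. The following is only a minimal sketch of step 1, assuming the default dictionary layout shown in `src/config.py`; replace the COCO paths with your real ones (see the `# change the following settings to real path` comment in that file):
```python
# src/config.py (sketch): point the COCO config at real data and make it the active config
cfg_unet_simple_coco['anno_json'] = '/data/coco2017/annotations/instances_train2017.json'
cfg_unet_simple_coco['val_anno_json'] = '/data/coco2017/annotations/instances_val2017.json'
cfg_unet_simple_coco['coco_dir'] = '/data/coco2017/train2017'
cfg_unet_simple_coco['val_coco_dir'] = '/data/coco2017/val2017'
cfg_unet = cfg_unet_simple_coco
```
With `'split': 1.0` in this configuration, the script writes the converted training and validation samples to `/data/save_data_path/train` and `/data/save_data_path/val` respectively.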
## [Environment Requirements](#contents)
@ -145,6 +196,7 @@ Then you can run everything just like on ascend.
├── mindspore_hub_conf.py // hub config file
├── postprocess.py // unet 310 infer postprocess.
├── preprocess.py // unet 310 infer preprocess dataset
├── preprocess_dataset.py // the script to adapt MultiClass dataset
├── requirements.txt // Requirements of third party package.
```

View File

@ -57,7 +57,58 @@ UNet++ is an enhanced version of U-Net that uses new cross-layer connections and deep supervision
- Data format: binary files (TIF)
- Note: data is processed in src/data_loader.py
We also support the `Cell_nuclei` dataset used in the [Unet++](https://arxiv.org/abs/1912.05074) original paper. It can be enabled by setting `'dataset': 'Cell_nuclei'` in `src/config.py`.
We also support a Multi-Class dataset format that obtains images and the corresponding labels through a fixed directory structure.
Each directory holds one original image and its label: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
The directory structure is as follows:
```path
.
└─dataset
└─0001
├─image.png
└─mask.png
└─0002
├─image.png
└─mask.png
...
└─xxxx
├─image.png
└─mask.png
```
The `split` parameter in the config splits all images into a training set and a validation set; the default `split` is 0.8.
When `split` is set to 1.0, the training and validation sets are split by directories, with the following structure:
```path
.
└─dataset
└─train
└─0001
├─image.png
└─mask.png
...
└─xxxx
├─image.png
└─mask.png
└─val
└─0001
├─image.png
└─mask.png
...
└─xxxx
├─image.png
└─mask.png
```
We provide a script to convert the COCO dataset and the Cell_Nuclei dataset (used in the [Unet++ original paper](https://arxiv.org/abs/1912.05074)) to the multi-class dataset format.
1. Modify `cfg_unet` in `src/config.py`; refer to `cfg_unet_nested_cell` and `cfg_unet_simple_coco` in `src/config.py` for details.
2. Run the conversion script:
```shell
python preprocess_dataset.py -d /data/save_data_path
```
## Environment Requirements
@ -149,6 +200,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
├── mindspore_hub_conf.py // hub config file
├── postprocess.py // unet 310 inference postprocess script
├── preprocess.py // unet 310 inference preprocess script
├── preprocess_dataset.py // script to adapt the MultiClass dataset
├── requirements.txt // required third-party packages
```

View File

@ -19,7 +19,7 @@ import logging
from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.data_loader import create_dataset, create_multi_class_dataset
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet
@ -44,9 +44,12 @@ def test_net(data_dir,
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net = UnetEval(net)
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
eval_resize=cfg["eval_resize"], split=0.8)
if 'dataset' in cfg and cfg['dataset'] != "ISBI":
split = cfg['split'] if 'split' in cfg else 0.8
valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
num_classes=cfg['num_classes'], is_train=False,
eval_resize=cfg["eval_resize"], split=split,
python_multiprocessing=False, shuffle=False)
else:
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size'])

View File

@ -0,0 +1,138 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Preprocess dataset.
Each sample folder contains one image: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
"""
import os
import argparse
import cv2
import numpy as np
from model_zoo.official.cv.unet.src.config import cfg_unet
def annToMask(ann, height, width):
"""Convert annotation to RLE and then to binary mask."""
from pycocotools import mask as maskHelper
segm = ann['segmentation']
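# segm can be a polygon list, uncompressed RLE (counts stored as a list), or compressed RLE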
if isinstance(segm, list):
rles = maskHelper.frPyObjects(segm, height, width)
rle = maskHelper.merge(rles)
elif isinstance(segm['counts'], list):
rle = maskHelper.frPyObjects(segm, height, width)
else:
rle = ann['segmentation']
m = maskHelper.decode(rle)
return m
def preprocess_cell_nuclei_dataset(param_dict):
"""
Preprocess for Cell Nuclei dataset.
Merge all instance masks into a single mask and save it at data_dir/img_id/mask.png.
"""
print("========== start preprocess Cell Nuclei dataset ==========")
data_dir = param_dict["data_dir"]
img_ids = sorted(next(os.walk(data_dir))[1])
for img_id in img_ids:
path = os.path.join(data_dir, img_id)
if (not os.path.exists(os.path.join(path, "image.png"))) or \
(not os.path.exists(os.path.join(path, "mask.png"))):
img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
if len(img.shape) == 2:
img = np.expand_dims(img, axis=-1)
img = np.concatenate([img, img, img], axis=-1)
mask = []
for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
mask.append(mask_)
mask = np.max(mask, axis=0)
cv2.imwrite(os.path.join(path, "image.png"), img)
cv2.imwrite(os.path.join(path, "mask.png"), mask)
def preprocess_coco_dataset(param_dict):
"""
Preprocess for coco dataset.
Save the image and mask at save_dir/img_name/image.png and save_dir/img_name/mask.png.
"""
print("========== start preprocess coco dataset ==========")
from pycocotools.coco import COCO
anno_json = param_dict["anno_json"]
coco_cls = param_dict["coco_classes"]
coco_dir = param_dict["coco_dir"]
save_dir = param_dict["save_dir"]
coco_cls_dict = {}
for i, cls in enumerate(coco_cls):
coco_cls_dict[cls] = i
coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]
image_ids = coco.getImgIds()
images_num = len(image_ids)
for ind, img_id in enumerate(image_ids):
image_info = coco.loadImgs(img_id)
file_name = image_info[0]["file_name"]
img_name, _ = os.path.splitext(file_name)
image_path = os.path.join(coco_dir, file_name)
if not os.path.isfile(image_path):
print("{}/{}: {} is in annotations but not exist".format(ind + 1, images_num, image_path))
continue
if not os.path.exists(os.path.join(save_dir, img_name)):
os.makedirs(os.path.join(save_dir, img_name))
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
h = coco.imgs[img_id]["height"]
w = coco.imgs[img_id]["width"]
mask = np.zeros((h, w), dtype=np.uint8)
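# paint each instance's class id only onto pixels that are still background (0), so earlier instances take precedence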
for instance in anno:
m = annToMask(instance, h, w)
c = coco_cls_dict[classs_dict[instance["category_id"]]]
if len(m.shape) < 3:
mask[:, :] += (mask == 0) * (m * c)
else:
mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
img = cv2.imread(image_path)
cv2.imwrite(os.path.join(save_dir, img_name, "image.png"), img)
cv2.imwrite(os.path.join(save_dir, img_name, "mask.png"), mask)
def preprocess_dataset(cfg, data_dir):
"""Select preprocess function."""
if cfg['dataset'].lower() == "cell_nuclei":
preprocess_cell_nuclei_dataset({"data_dir": data_dir})
elif cfg['dataset'].lower() == "coco":
if 'split' in cfg and cfg['split'] == 1.0:
train_data_path = os.path.join(data_dir, "train")
val_data_path = os.path.join(data_dir, "val")
train_param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
"coco_dir": cfg["coco_dir"], "save_dir": train_data_path}
preprocess_coco_dataset(train_param_dict)
val_param_dict = {"anno_json": cfg["val_anno_json"], "coco_classes": cfg["coco_classes"],
"coco_dir": cfg["val_coco_dir"], "save_dir": val_data_path}
preprocess_coco_dataset(val_param_dict)
else:
param_dict = {"anno_json": cfg["anno_json"], "coco_classes": cfg["coco_classes"],
"coco_dir": cfg["coco_dir"], "save_dir": data_dir}
preprocess_coco_dataset(param_dict)
else:
raise ValueError("Not support dataset mode {}".format(cfg['dataset']))
print("========== end preprocess dataset ==========")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='save data directory')
args = parser.parse_args()
preprocess_dataset(cfg_unet, args.data_url)

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,6 +14,14 @@
# limitations under the License.
# ============================================================================
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# != 2 ]
then
echo "=============================================================================================================="
@ -24,10 +32,11 @@ then
echo "=============================================================================================================="
exit 1
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export HCCL_CONNECT_TIMEOUT=600
export RANK_SIZE=8
DATASET=$(get_real_path $2)
export RANK_TABLE_FILE=$(get_real_path $1)
for((i=0;i<RANK_SIZE;i++))
do
rm -rf LOG$i
@ -35,16 +44,15 @@ do
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cd ./LOG$i || exit
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export RANK_ID=$i
export DEVICE_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python3 train.py \
python3 ${PROJECT_DIR}/../train.py \
--run_distribute=True \
--data_url=$2 > log.txt 2>&1 &
--data_url=$DATASET > log.txt 2>&1 &
cd ../
done

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,14 +14,31 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# != 2 ] && [ $# != 3 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [DEVICE_ID](option, default is 0)"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/ 0"
echo "=============================================================================================================="
exit 1
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 &
if [ $# != 2 ]
then
export DEVICE_ID=$3
fi
DATASET=$(get_real_path $1)
CHECKPOINT=$(get_real_path $2)
echo "========== start run evaluation ==========="
echo "please get log at eval.log"
python ${PROJECT_DIR}/../eval.py --data_url=$DATASET --ckpt_path=$CHECKPOINT > eval.log 2>&1 &

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,14 +14,31 @@
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# != 1 ] && [ $# != 2 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train.sh [DATASET]"
echo "for example: bash run_standalone_train.sh /path/to/data/"
echo "bash scripts/run_standalone_train.sh [DATASET] [DEVICE_ID](option, default is 0)"
echo "for example: bash run_standalone_train.sh /path/to/data/ 0"
echo "=============================================================================================================="
exit 1
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
if [ $# != 1 ]
then
export DEVICE_ID=$2
fi
export DEVICE_ID=0
python train.py --data_url=$1 > train.log 2>&1 &
DATASET=$(get_real_path $1)
echo "========== start run training ==========="
echo "please get log at train.log"
python ${PROJECT_DIR}/../train.py --data_url=$DATASET > train.log 2>&1 &

View File

@ -124,6 +124,54 @@ cfg_unet_simple = {
'eval_resize': False
}
cfg_unet_simple_coco = {
'model': 'unet_simple',
'dataset': 'COCO',
'split': 1.0,
'img_size': [512, 512],
'lr': 3e-4,
'epochs': 80,
'repeat': 1,
'distribute_epochs': 120,
'cross_valid_ind': 1,
'batchsize': 16,
'num_channels': 3,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ["final.weight"],
'eval_activate': 'Softmax',
'eval_resize': False,
'num_classes': 81,
'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
# change the following settings to real path
'anno_json': '/data/coco2017/annotations/instances_train2017.json',
'val_anno_json': '/data/coco2017/annotations/instances_val2017.json',
'coco_dir': '/data/coco2017/train2017',
'val_coco_dir': '/data/coco2017/val2017'
}
cfg_unet = cfg_unet_simple
if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']:
print("ISBI dataset not support resize to original image size when in evaluation.")

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -165,38 +165,33 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
return train_ds, valid_ds
class CellNucleiDataset:
class MultiClassDataset:
"""
Cell nuclei dataset preprocess class.
Read image and mask from original images, and split all data into train_dataset and val_dataset by `split`.
Get image path and mask path from a tree of directories,
where each folder contains one sample: the image file is named `"image.png"` and the mask file is named `"mask.png"`.
"""
def __init__(self, data_dir, repeat, is_train=False, split=0.8):
def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False):
self.data_dir = data_dir
self.img_ids = sorted(next(os.walk(self.data_dir))[1])
self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
np.random.shuffle(self.train_ids)
self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
self.is_train = is_train
self._preprocess_dataset()
def _preprocess_dataset(self):
for img_id in self.img_ids:
path = os.path.join(self.data_dir, img_id)
if (not os.path.exists(os.path.join(path, "image.png"))) or \
(not os.path.exists(os.path.join(path, "mask.png"))):
img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
if len(img.shape) == 2:
img = np.expand_dims(img, axis=-1)
img = np.concatenate([img, img, img], axis=-1)
mask = []
for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
mask.append(mask_)
mask = np.max(mask, axis=0)
cv2.imwrite(os.path.join(path, "image.png"), img)
cv2.imwrite(os.path.join(path, "mask.png"), mask)
self.split = (split != 1.0)
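# split < 1.0: split one flat directory by ratio; split == 1.0: read explicit train/ and val/ subdirectories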
if self.split:
self.img_ids = sorted(next(os.walk(self.data_dir))[1])
self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
else:
self.train_ids = sorted(next(os.walk(os.path.join(self.data_dir, "train")))[1])
self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1])
if shuffle:
np.random.shuffle(self.train_ids)
def _read_img_mask(self, img_id):
path = os.path.join(self.data_dir, img_id)
if self.split:
path = os.path.join(self.data_dir, img_id)
elif self.is_train:
path = os.path.join(self.data_dir, "train", img_id)
else:
path = os.path.join(self.data_dir, "val", img_id)
img = cv2.imread(os.path.join(path, "image.png"))
mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
return img, mask
@ -216,9 +211,9 @@ class CellNucleiDataset:
return len(self.train_ids)
return len(self.val_ids)
def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False):
def preprocess_img_mask(img, mask, num_classes, img_size, augment=False, eval_resize=False):
"""
Preprocess for cell nuclei dataset.
Preprocess for multi-class dataset.
Random crop and flip images and masks when augment is True.
"""
if augment:
@ -240,24 +235,28 @@ def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False):
mask = cv2.resize(mask, img_size)
img = (img.astype(np.float32) - 127.5) / 127.5
img = img.transpose(2, 0, 1)
mask = mask.astype(np.float32) / 255
mask = (mask > 0.5).astype(np.int)
mask = (np.arange(2) == mask[..., None]).astype(int)
if num_classes == 2:
mask = mask.astype(np.float32) / mask.max()
mask = (mask > 0.5).astype(np.int)
else:
mask = mask.astype(np.int)
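# one-hot encode the class-id mask: (H, W) -> (H, W, num_classes)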
mask = (np.arange(num_classes) == mask[..., None]).astype(int)
mask = mask.transpose(2, 0, 1).astype(np.float32)
return img, mask
def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False,
split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_classes=2, is_train=False, augment=False,
eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True,
num_parallel_workers=8, shuffle=True):
"""
Get generator dataset for cell nuclei dataset.
Get generator dataset for multi-class dataset.
"""
cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train,
eval_resize))
dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle)
sampler = ds.DistributedSampler(group_size, rank, shuffle=shuffle)
dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, sampler=sampler)
compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size),
augment and is_train, eval_resize))
dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names,
output_columns=mc_dataset.column_names, column_order=mc_dataset.column_names,
python_multiprocessing=python_multiprocessing,
num_parallel_workers=num_parallel_workers)
dataset = dataset.batch(batch_size, drop_remainder=is_train)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -28,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.data_loader import create_dataset, create_multi_class_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
from src.config import cfg_unet
@ -79,16 +79,18 @@ def train_net(args_opt,
criterion = MultiCrossEntropyWithLogits()
else:
criterion = CrossEntropyWithLogits()
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
repeat = cfg['repeat']
if 'dataset' in cfg and cfg['dataset'] != "ISBI":
repeat = cfg['repeat'] if 'repeat' in cfg else 1
split = cfg['split'] if 'split' in cfg else 0.8
dataset_sink_mode = True
per_print_times = 0
train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
is_train=True, augment=True, split=0.8, rank=rank,
group_size=group_size)
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
eval_resize=cfg["eval_resize"], split=0.8,
python_multiprocessing=False)
train_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], repeat, batch_size,
num_classes=cfg['num_classes'], is_train=True, augment=True,
split=split, rank=rank, group_size=group_size, shuffle=True)
valid_dataset = create_multi_class_dataset(data_dir, cfg['img_size'], 1, 1,
num_classes=cfg['num_classes'], is_train=False,
eval_resize=cfg["eval_resize"], split=split,
python_multiprocessing=False, shuffle=False)
else:
repeat = cfg['repeat']
dataset_sink_mode = False