!12685 Add nested-unet

From: @c_34
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-03-02 11:26:05 +08:00 committed by Gitee
commit b9fac815bc
16 changed files with 371 additions and 40 deletions

View File

@ -25,7 +25,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss.loss import _Loss
from src.data_loader import create_dataset
from src.unet import UNet
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet
from scipy.special import softmax
@ -34,8 +35,6 @@ device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
class CrossEntropyWithLogits(_Loss):
def __init__(self):
super(CrossEntropyWithLogits, self).__init__()
@ -64,10 +63,11 @@ class dice_coeff(nn.Metric):
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Mean dice coeffcient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
y_pred = y_pred.transpose(0, 2, 3, 1)
y = y.transpose(0, 2, 3, 1)
@ -90,13 +90,20 @@ def test_net(data_dir,
ckpt_path,
cross_valid_ind=1,
cfg=None):
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:
raise ValueError("Unsupported model: {}".format(cfg['model']))
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
criterion = CrossEntropyWithLogits()
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size'])
model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
print("============== Starting Evaluating ============")

View File

@ -18,7 +18,8 @@ import numpy as np
from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context
from src.unet.unet_model import UNet
from src.unet_medical.unet_model import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet as cfg
parser = argparse.ArgumentParser(description='unet export')
@ -38,7 +39,14 @@ if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"])
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:
raise ValueError("Unsupported model: {}".format(cfg['model']))
# return a parameter dict for model
param_dict = load_checkpoint(args.ckpt_file)
# load the parameter into net

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""hub config."""
from src.unet import UNet
from src.unet_medical import UNet
def create_network(name, *args, **kwargs):
if name == "unet2d":

View File

@ -14,15 +14,14 @@
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data"
echo "=============================================================================================================="
if [ $# != 2 ]
then
echo "=============================================================================================================="
echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
echo "Please run the script as: "
echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data"
echo "=============================================================================================================="
exit 1
fi

View File

@ -14,11 +14,14 @@
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
echo "=============================================================================================================="
if [ $# != 2 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
echo "=============================================================================================================="
fi
export DEVICE_ID=0
python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 &

View File

@ -14,11 +14,14 @@
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train.sh [DATASET]"
echo "for example: bash run_standalone_train.sh /path/to/data/"
echo "=============================================================================================================="
if [ $# != 1 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train.sh [DATASET]"
echo "for example: bash run_standalone_train.sh /path/to/data/"
echo "=============================================================================================================="
fi
export DEVICE_ID=0
python train.py --data_url=$1 > train.log 2>&1 &

View File

@ -13,7 +13,10 @@
# limitations under the License.
# ============================================================================
cfg_unet = {
cfg_unet_medical = {
'model': 'unet_medical',
'crop': [388 / 572, 388 / 572],
'img_size': [572, 572],
'lr': 0.0001,
'epochs': 400,
'distribute_epochs': 1600,
@ -30,3 +33,47 @@ cfg_unet = {
'resume': False,
'resume_ckpt': './',
}
cfg_unet_nested = {
'model': 'unet_nested',
'crop': None,
'img_size': [576, 576],
'lr': 0.0001,
'epochs': 400,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 1,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
}
cfg_unet_simple = {
'model': 'unet_simple',
'crop': None,
'img_size': [576, 576],
'lr': 0.0001,
'epochs': 400,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 1,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'resume': False,
'resume_ckpt': './',
}
cfg_unet = cfg_unet_medical

View File

@ -82,7 +82,8 @@ def train_data_augmentation(img, mask):
return img, mask
def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False):
def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False,
do_crop=None, img_size=None):
images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif'))
masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif'))
@ -121,8 +122,12 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
c_resize_op = c_vision.Resize(size=(388, 388), interpolation=Inter.BILINEAR)
c_pad = c_vision.Pad(padding=92)
if do_crop:
resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))]
else:
resize_size = img_size
c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2)
c_rescale_image = c_vision.Rescale(1.0/127.5, -1)
c_rescale_mask = c_vision.Rescale(1.0/255.0, 0)
@ -136,12 +141,13 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
train_ds = train_ds.project(columns=["image", "mask"])
if augment:
augment_process = train_data_augmentation
c_resize_op = c_vision.Resize(size=(572, 572), interpolation=Inter.BILINEAR)
c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR)
train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process)
train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
if do_crop:
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process
train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
train_ds = train_ds.shuffle(repeat*24)
@ -151,7 +157,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
valid_ds = valid_ds.project(columns=["image", "mask"])
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
if do_crop:
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process
valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)

View File

@ -13,4 +13,4 @@
# limitations under the License.
# ============================================================================
from .unet_model import UNet
from .unet_model import UNetMedical

View File

@ -13,12 +13,12 @@
# limitations under the License.
# ============================================================================
from src.unet.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv
from src.unet_medical.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv
import mindspore.nn as nn
class UNet(nn.Cell):
class UNetMedical(nn.Cell):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
super(UNetMedical, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.inc = DoubleConv(n_channels, 64)

View File

@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .unet_model import NestedUNet, UNet

View File

@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Model of UnetPlusPlus
import mindspore.nn as nn
from .unet_parts import UnetConv2d, UnetUp
class NestedUNet(nn.Cell):
"""
Nested unet
"""
def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True, use_ds=True):
super(NestedUNet, self).__init__()
self.in_channel = in_channel
self.n_class = n_class
self.feature_scale = feature_scale
self.use_deconv = use_deconv
self.use_bn = use_bn
self.use_ds = use_ds
filters = [64, 128, 256, 512, 1024]
filters = [int(x / self.feature_scale) for x in filters]
# Down Sample
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
self.conv00 = UnetConv2d(self.in_channel, filters[0], self.use_bn)
self.conv10 = UnetConv2d(filters[0], filters[1], self.use_bn)
self.conv20 = UnetConv2d(filters[1], filters[2], self.use_bn)
self.conv30 = UnetConv2d(filters[2], filters[3], self.use_bn)
self.conv40 = UnetConv2d(filters[3], filters[4], self.use_bn)
# Up Sample
self.up_concat01 = UnetUp(filters[1], filters[0], self.use_deconv, 2)
self.up_concat11 = UnetUp(filters[2], filters[1], self.use_deconv, 2)
self.up_concat21 = UnetUp(filters[3], filters[2], self.use_deconv, 2)
self.up_concat31 = UnetUp(filters[4], filters[3], self.use_deconv, 2)
self.up_concat02 = UnetUp(filters[1], filters[0], self.use_deconv, 3)
self.up_concat12 = UnetUp(filters[2], filters[1], self.use_deconv, 3)
self.up_concat22 = UnetUp(filters[3], filters[2], self.use_deconv, 3)
self.up_concat03 = UnetUp(filters[1], filters[0], self.use_deconv, 4)
self.up_concat13 = UnetUp(filters[2], filters[1], self.use_deconv, 4)
self.up_concat04 = UnetUp(filters[1], filters[0], self.use_deconv, 5)
# Finale Convolution
self.final1 = nn.Conv2d(filters[0], n_class, 1)
self.final2 = nn.Conv2d(filters[0], n_class, 1)
self.final3 = nn.Conv2d(filters[0], n_class, 1)
self.final4 = nn.Conv2d(filters[0], n_class, 1)
def construct(self, inputs):
x00 = self.conv00(inputs) # channel = filters[0]
x10 = self.conv10(self.maxpool(x00)) # channel = filters[1]
x20 = self.conv20(self.maxpool(x10)) # channel = filters[2]
x30 = self.conv30(self.maxpool(x20)) # channel = filters[3]
x40 = self.conv40(self.maxpool(x30)) # channel = filters[4]
x01 = self.up_concat01(x10, x00) # channel = filters[0]
x11 = self.up_concat11(x20, x10) # channel = filters[1]
x21 = self.up_concat21(x30, x20) # channel = filters[2]
x31 = self.up_concat31(x40, x30) # channel = filters[3]
x02 = self.up_concat02(x11, x00, x01) # channel = filters[0]
x12 = self.up_concat12(x21, x10, x11) # channel = filters[1]
x22 = self.up_concat22(x31, x20, x21) # channel = filters[2]
x03 = self.up_concat03(x12, x00, x01, x02) # channel = filters[0]
x13 = self.up_concat13(x22, x10, x11, x12) # channel = filters[1]
x04 = self.up_concat04(x13, x00, x01, x02, x03) # channel = filters[0]
final1 = self.final1(x01)
final2 = self.final1(x02)
final3 = self.final1(x03)
final4 = self.final1(x04)
final = (final1 + final2 + final3 + final4) / 4.0
if self.use_ds:
return final
return final4
class UNet(nn.Cell):
"""
Simple UNet with skip connection
"""
def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True):
super(UNet, self).__init__()
self.in_channel = in_channel
self.n_class = n_class
self.feature_scale = feature_scale
self.use_deconv = use_deconv
self.use_bn = use_bn
filters = [64, 128, 256, 512, 1024]
filters = [int(x / self.feature_scale) for x in filters]
# Down Sample
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
self.conv0 = UnetConv2d(self.in_channel, filters[0], self.use_bn)
self.conv1 = UnetConv2d(filters[0], filters[1], self.use_bn)
self.conv2 = UnetConv2d(filters[1], filters[2], self.use_bn)
self.conv3 = UnetConv2d(filters[2], filters[3], self.use_bn)
self.conv4 = UnetConv2d(filters[3], filters[4], self.use_bn)
# Up Sample
self.up_concat1 = UnetUp(filters[1], filters[0], self.use_deconv, 2)
self.up_concat2 = UnetUp(filters[2], filters[1], self.use_deconv, 2)
self.up_concat3 = UnetUp(filters[3], filters[2], self.use_deconv, 2)
self.up_concat4 = UnetUp(filters[4], filters[3], self.use_deconv, 2)
# Finale Convolution
self.final = nn.Conv2d(filters[0], n_class, 1)
def construct(self, inputs):
x0 = self.conv0(inputs) # channel = filters[0]
x1 = self.conv1(self.maxpool(x0)) # channel = filters[1]
x2 = self.conv2(self.maxpool(x1)) # channel = filters[2]
x3 = self.conv3(self.maxpool(x2)) # channel = filters[3]
x4 = self.conv4(self.maxpool(x3)) # channel = filters[4]
up4 = self.up_concat4(x4, x3)
up3 = self.up_concat3(up4, x2)
up2 = self.up_concat2(up3, x1)
up1 = self.up_concat1(up2, x0)
final = self.final(up1)
return final

View File

@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Parts of the U-Net-PlusPlus model """
import mindspore.nn as nn
import mindspore.ops.functional as F
import mindspore.ops.operations as P
def conv_bn_relu(in_channel, out_channel, use_bn=True, kernel_size=3, stride=1, pad_mode="same", activation='relu'):
output = []
output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode=pad_mode))
if use_bn:
output.append(nn.BatchNorm2d(out_channel))
if activation:
output.append(nn.get_activation(activation))
return nn.SequentialCell(output)
class UnetConv2d(nn.Cell):
"""
Convolution block in Unet, usually double conv.
"""
def __init__(self, in_channel, out_channel, use_bn=True, num_layer=2, kernel_size=3, stride=1, padding='same'):
super(UnetConv2d, self).__init__()
self.num_layer = num_layer
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.in_channel = in_channel
self.out_channel = out_channel
convs = []
for _ in range(num_layer):
convs.append(conv_bn_relu(in_channel, out_channel, use_bn, kernel_size, stride, padding, "relu"))
in_channel = out_channel
self.convs = nn.SequentialCell(convs)
def construct(self, inputs):
x = self.convs(inputs)
return x
class UnetUp(nn.Cell):
"""
Upsampling high_feature with factor=2 and concat with low feature
"""
def __init__(self, in_channel, out_channel, use_deconv, n_concat=2):
super(UnetUp, self).__init__()
self.conv = UnetConv2d(in_channel + (n_concat - 2) * out_channel, out_channel, False)
self.concat = P.Concat(axis=1)
self.use_deconv = use_deconv
if use_deconv:
self.up_conv = nn.Conv2dTranspose(in_channel, out_channel, kernel_size=2, stride=2, pad_mode="same")
else:
self.up_conv = nn.Conv2d(in_channel, out_channel, 1)
def construct(self, high_feature, *low_feature):
if self.use_deconv:
output = self.up_conv(high_feature)
else:
_, _, h, w = F.shape(high_feature)
output = P.ResizeBilinear((h * 2, w * 2))(high_feature)
output = self.up_conv(output)
for feature in low_feature:
output = self.concat((output, feature))
return self.conv(output)

View File

@ -15,6 +15,7 @@
import time
import numpy as np
from PIL import Image
from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor
@ -53,3 +54,6 @@ class StepLossTimeMonitor(Callback):
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
# TEST
print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))

View File

@ -26,7 +26,8 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet import UNet
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset
from src.loss import CrossEntropyWithLogits
from src.utils import StepLossTimeMonitor
@ -53,14 +54,23 @@ def train_net(data_dir,
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size,
gradients_mean=False)
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:
raise ValueError("Unsupported model: {}".format(cfg['model']))
if cfg['resume']:
param_dict = load_checkpoint(cfg['resume_ckpt'])
load_param_into_net(net, param_dict)
criterion = CrossEntropyWithLogits()
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"],
cfg['img_size'])
train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,