forked from mindspore-Ecosystem/mindspore
!14410 modify conv3d and conv3dtranspose for unet3d
From: @Somnus2020 Reviewed-by: @liangchenghui,@c_34 Signed-off-by: @liangchenghui
This commit is contained in:
commit
ff1f27138f
|
@ -102,7 +102,6 @@ python eval.py --data_url=/path/to/data/ --seg_url=/path/to/segment/ --ckpt_path
|
||||||
│ ├──transform.py // handle dataset
|
│ ├──transform.py // handle dataset
|
||||||
│ ├──convert_nifti.py // convert dataset
|
│ ├──convert_nifti.py // convert dataset
|
||||||
│ ├──loss.py // loss
|
│ ├──loss.py // loss
|
||||||
│ ├──conv.py // conv components
|
|
||||||
│ ├──utils.py // General components (callback function)
|
│ ├──utils.py // General components (callback function)
|
||||||
│ ├──unet3d_model.py // Unet3D model
|
│ ├──unet3d_model.py // Unet3D model
|
||||||
│ ├──unet3d_parts.py // Unet3D part
|
│ ├──unet3d_parts.py // Unet3D part
|
||||||
|
|
|
@ -1,90 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# less required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import mindspore.nn as nn
|
|
||||||
from mindspore import Parameter
|
|
||||||
from mindspore import dtype as mstype
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
from mindspore.ops.operations import nn_ops as nps
|
|
||||||
from mindspore.common.initializer import initializer
|
|
||||||
|
|
||||||
def weight_variable(shape):
|
|
||||||
init_value = initializer('Normal', shape, mstype.float32)
|
|
||||||
return Parameter(init_value)
|
|
||||||
|
|
||||||
class Conv3D(nn.Cell):
|
|
||||||
def __init__(self,
|
|
||||||
in_channel,
|
|
||||||
out_channel,
|
|
||||||
kernel_size,
|
|
||||||
mode=1,
|
|
||||||
pad_mode="valid",
|
|
||||||
pad=0,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
group=1,
|
|
||||||
data_format="NCDHW",
|
|
||||||
bias_init="zeros",
|
|
||||||
has_bias=True):
|
|
||||||
super().__init__()
|
|
||||||
self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
|
||||||
self.weight = weight_variable(self.weight_shape)
|
|
||||||
self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
|
|
||||||
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
|
|
||||||
group=group, data_format=data_format)
|
|
||||||
self.bias_init = bias_init
|
|
||||||
self.has_bias = has_bias
|
|
||||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
|
||||||
if self.has_bias:
|
|
||||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
output = self.conv(x, self.weight)
|
|
||||||
if self.has_bias:
|
|
||||||
output = self.bias_add(output, self.bias)
|
|
||||||
return output
|
|
||||||
|
|
||||||
class Conv3DTranspose(nn.Cell):
|
|
||||||
def __init__(self,
|
|
||||||
in_channel,
|
|
||||||
out_channel,
|
|
||||||
kernel_size,
|
|
||||||
mode=1,
|
|
||||||
pad=0,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
group=1,
|
|
||||||
output_padding=0,
|
|
||||||
data_format="NCDHW",
|
|
||||||
bias_init="zeros",
|
|
||||||
has_bias=True):
|
|
||||||
super().__init__()
|
|
||||||
self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
|
||||||
self.weight = weight_variable(self.weight_shape)
|
|
||||||
self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
|
|
||||||
kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
|
|
||||||
dilation=dilation, group=group, output_padding=output_padding, \
|
|
||||||
data_format=data_format)
|
|
||||||
self.bias_init = bias_init
|
|
||||||
self.has_bias = has_bias
|
|
||||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
|
||||||
if self.has_bias:
|
|
||||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
output = self.conv_transpose(x, self.weight)
|
|
||||||
if self.has_bias:
|
|
||||||
output = self.bias_add(output, self.bias)
|
|
||||||
return output
|
|
|
@ -16,7 +16,6 @@
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from src.conv import Conv3D, Conv3DTranspose
|
|
||||||
|
|
||||||
class BatchNorm3d(nn.Cell):
|
class BatchNorm3d(nn.Cell):
|
||||||
def __init__(self, num_features):
|
def __init__(self, num_features):
|
||||||
|
@ -39,22 +38,22 @@ class ResidualUnit(nn.Cell):
|
||||||
self.down = down
|
self.down = down
|
||||||
self.in_channel = in_channel
|
self.in_channel = in_channel
|
||||||
self.out_channel = out_channel
|
self.out_channel = out_channel
|
||||||
self.down_conv_1 = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
self.down_conv_1 = nn.Conv3d(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||||
pad_mode="pad", stride=self.stride, pad=1)
|
pad_mode="pad", stride=self.stride, padding=1)
|
||||||
self.is_output = is_output
|
self.is_output = is_output
|
||||||
if not is_output:
|
if not is_output:
|
||||||
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
|
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
|
||||||
self.relu1 = nn.PReLU()
|
self.relu1 = nn.PReLU()
|
||||||
if self.down:
|
if self.down:
|
||||||
self.down_conv_2 = Conv3D(out_channel, out_channel, kernel_size=(3, 3, 3), \
|
self.down_conv_2 = nn.Conv3d(out_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||||
pad_mode="pad", stride=1, pad=1)
|
pad_mode="pad", stride=1, padding=1)
|
||||||
self.relu2 = nn.PReLU()
|
self.relu2 = nn.PReLU()
|
||||||
if kernel_size[0] == 1:
|
if kernel_size[0] == 1:
|
||||||
self.residual = Conv3D(in_channel, out_channel, kernel_size=(1, 1, 1), \
|
self.residual = nn.Conv3d(in_channel, out_channel, kernel_size=(1, 1, 1), \
|
||||||
pad_mode="valid", stride=self.stride)
|
pad_mode="valid", stride=self.stride)
|
||||||
else:
|
else:
|
||||||
self.residual = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
self.residual = nn.Conv3d(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||||
pad_mode="pad", stride=self.stride, pad=1)
|
pad_mode="pad", stride=self.stride, padding=1)
|
||||||
self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)
|
self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,9 +92,10 @@ class Up(nn.Cell):
|
||||||
self.down_in_channel = down_in_channel
|
self.down_in_channel = down_in_channel
|
||||||
self.out_channel = out_channel
|
self.out_channel = out_channel
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.conv3d_transpose = Conv3DTranspose(in_channel=self.in_channel + self.down_in_channel, \
|
self.conv3d_transpose = nn.Conv3dTranspose(in_channels=self.in_channel + self.down_in_channel, \
|
||||||
pad=1, out_channel=self.out_channel, kernel_size=(3, 3, 3), \
|
out_channels=self.out_channel, kernel_size=(3, 3, 3), \
|
||||||
stride=self.stride, output_padding=(1, 1, 1))
|
pad_mode="pad", stride=self.stride, \
|
||||||
|
output_padding=(1, 1, 1), padding=1)
|
||||||
|
|
||||||
self.concat = P.Concat(axis=1)
|
self.concat = P.Concat(axis=1)
|
||||||
self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
|
self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
|
||||||
|
|
Loading…
Reference in New Issue