forked from mindspore-Ecosystem/mindspore

commit 9cf856dd25 (parent 304211e2b6)

    amend some script

@@ -85,12 +85,12 @@ def get_bprop_conv3d(self):
@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
    """Grad definition for `Conv3DTranspose` operation."""
    filter_grad = G.Conv3DBackpropFilter(
        out_channel=self.out_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode=self.pad_mode,
    input_grad = nps.Conv3D(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    input_grad = nps.Conv3D(
        out_channel=self.out_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode=self.pad_mode,
    filter_grad = G.Conv3DBackpropFilter(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    input_size = self.input_size
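
Editor's note: a rough sketch of how this bprop could be exercised from the public API (not part of the commit; it assumes the P.Conv3DTranspose alias, PyNative mode, and a backend that implements these 3D ops — Ascend at the time — and reuses the shapes from the Conv3DTranspose docstring example further down):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, context
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    conv3d_transpose = P.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2))

    def forward(x, w):
        return conv3d_transpose(x, w)

    x = Tensor(np.ones([32, 16, 10, 32, 32]), ms.float32)
    w = Tensor(np.ones([16, 3, 4, 6, 2]), ms.float32)
    # gradients w.r.t. both inputs; this routes through get_bprop_conv3d_transpose above
    dx, dw = C.GradOperation(get_all=True)(forward)(x, w)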

@@ -371,7 +371,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")
        self.add_prim_attr('io_format', self.format)

    def __infer__(self, x, doutput, w_size):
        w_size_v = w_size['value']

@@ -6831,7 +6831,7 @@ class Conv3D(PrimitiveWithInfer):
        self.add_prim_attr('mode', self.mode)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")
        self.add_prim_attr('io_format', self.format)
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)

@@ -6988,7 +6988,7 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")
        self.add_prim_attr('io_format', self.format)

    def __infer__(self, w, doutput, x_size):
        x_size_v = x_size['value']

@@ -7044,13 +7044,10 @@ class Conv3DTranspose(PrimitiveWithInfer):
    Computes the gradients of convolution 3D with respect to the input.

    Args:
        input_size (tuple[int]): The shape of the output with five integers. If input_size is set to (0, 0, 0, 0, 0),
            it will activate the output_padding function. Otherwise, the output shape will be the same as input_size,
            the output_padding setting will be invalid, and pad_mode cannot be set to 'same'. Default: (0, 0, 0, 0, 0).
        out_channel (int): The dimension of the output.
        in_channel (int): The channel of the input x.
        out_channel (int): The channel of the weight x.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        mode (int): Modes for different convolutions. Not currently used.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
            head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six integers,
            the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3], pad[4]

@@ -7078,7 +7075,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
    Examples:
        >>> input_x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
        >>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float32)
        >>> conv3d_transpose = P.Conv3DTranspose(out_channel=4, kernel_size=(4, 6, 2))
        >>> conv3d_transpose = P.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2))
        >>> output = conv3d_transpose(input_x, weight)
        >>> print(output.shape)
        (32, 3, 13, 37, 33)
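
Editor's note: the example shape can be reproduced with the output-size formula used in __infer__ further down; with no padding, unit stride and dilation, and zero output_padding, the arithmetic is (values copied from the example above):

    d_out = (10 - 1) * 1 + 1 * (4 - 1) + 1   # = 13
    h_out = (32 - 1) * 1 + 1 * (6 - 1) + 1   # = 37
    w_out = (32 - 1) * 1 + 1 * (2 - 1) + 1   # = 33
    # batch stays 32 and the output channel comes from the weight's second axis (3),
    # giving (32, 3, 13, 37, 33) as printed above.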

@@ -7086,11 +7083,10 @@ class Conv3DTranspose(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 input_size=(0, 0, 0, 0, 0),
                 mode=1,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1,

@@ -7098,13 +7094,11 @@ class Conv3DTranspose(PrimitiveWithInfer):
                 output_padding=0,
                 data_format="NCDHW"):
        """Initialize Conv3DTranspose"""
        self.init_prim_io_names(inputs=['x', 'filter', 'input_size'], outputs=['output'])
        self.input_size = validator.check_value_type('input_size', input_size, [tuple], self.name)
        validator.check_equal_int(len(self.input_size), 5, 'input_size', self.name)
        for i, dim_len in enumerate(self.input_size):
            validator.check_value_type("input_size[%d]" % i, dim_len, [int], self.name)
        self.add_prim_attr('input_size', self.input_size)
        self.init_prim_io_names(inputs=['x', 'filter'], outputs=['output'])
        self.in_channel = validator.check_positive_int(in_channel, 'in_channel', self.name)
        self.add_prim_attr('in_channel', self.in_channel)
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.add_prim_attr('out_channel', self.out_channel)
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.stride)

@@ -7115,88 +7109,52 @@ class Conv3DTranspose(PrimitiveWithInfer):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.pad_list = pad

        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
        if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
        if self.pad_mode == 'pad':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        for item in self.pad_list:
            validator.check_non_negative_int(item, 'pad item', self.name)
        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
        self.add_prim_attr('mode', self.mode)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")
        self.add_prim_attr('io_format', self.format)

        self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
                                                     allow_five=True, ret_five=True, greater_zero=False)
        if self.output_padding[2] < 0 or (self.output_padding[2] >= self.dilation[2]
                                          and self.output_padding[2] >= self.stride[2]):
            raise ValueError("In op, the value of [output_padding D] should be [[0, max(stride D,dilation D))], "
                             "but it is {}.".format(self.output_padding[2]))
        if self.output_padding[3] < 0 or (self.output_padding[3] >= self.dilation[3]
                                          and self.output_padding[3] >= self.stride[3]):
            raise ValueError("In op, the value of [output_padding H] should be [[0, max(stride H,dilation H))], "
                             "but it is {}.".format(self.output_padding[3]))
        if self.output_padding[4] < 0 or (self.output_padding[4] >= self.dilation[4]
                                          and self.output_padding[4] >= self.stride[4]):
            raise ValueError("In op, the value of [output_padding W] should be [[0, max(stride W,dilation W))], "
                             "but it is {}.".format(self.output_padding[4]))
        self.add_prim_attr('output_padding', self.output_padding)
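
Editor's note: the three checks above apply the same rule to each spatial dimension; stated as a standalone predicate (an illustrative paraphrase, not code from this commit):

    def output_padding_is_valid(output_padding, stride, dilation):
        # accepted iff 0 <= output_padding < max(stride, dilation)
        return 0 <= output_padding < max(stride, dilation)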

    def __infer__(self, x, w, b=None):
        args = {'x': x['dtype'], 'w': w['dtype']}
        if b is not None:
            args = {'x': x['dtype'], 'w': w['dtype'], 'b': b['dtype']}
            raise ValueError("Bias currently only support None.")
        valid_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)

        output_shape = self.input_size
        validator.check("filter's batch", w['shape'][0], "input x's channel", x['shape'][1], Rel.EQ, self.name)
        # infer shape
        x_shape = x['shape']
        kernel_d = self.kernel_size[0]
        kernel_h = self.kernel_size[1]
        kernel_w = self.kernel_size[2]
        stride_d = self.stride[2]
        stride_h = self.stride[3]
        stride_w = self.stride[4]
        dilation_d = self.dilation[2]
        dilation_h = self.dilation[3]
        dilation_w = self.dilation[4]
        if self.input_size != (0, 0, 0, 0, 0):
            # The pad_mode is valid by default. If pad_mode is not valid or same, then pad.
            if self.pad_mode == "valid":
                self.pad_list = (0, 0, 0, 0, 0, 0)
            if self.pad_mode == "same":
                pad_needed_d = max(0, (x_shape[2] - 1) * stride_d + dilation_d *
                                   (kernel_d - 1) + 1 - self.input_size[2])
                pad_head = math.floor(pad_needed_d / 2)
                pad_tail = pad_needed_d - pad_head

                pad_needed_h = max(0, (x_shape[3] - 1) * stride_h + dilation_h *
                                   (kernel_h - 1) + 1 - self.input_size[3])
                pad_top = math.floor(pad_needed_h / 2)
                pad_bottom = pad_needed_h - pad_top

                pad_needed_w = max(0, (x_shape[4] - 1) * stride_w + dilation_w *
                                   (kernel_w - 1) + 1 - self.input_size[4])
                pad_left = math.floor(pad_needed_w / 2)
                pad_right = pad_needed_w - pad_left
                self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)

            self.add_prim_attr('pads', self.pad_list)
        else:
            self.add_prim_attr('pads', self.pad_list)
            if self.pad_mode == 'same':
                raise ValueError("When input_size is (0, 0, 0, 0, 0), the pad_mode cannot be 'same'!")
            pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.pad_list
            w_shape = w['shape']
            self.output_padding = self.output_padding if self.format == "NCDHW" else \
                (self.output_padding[0], self.output_padding[4], self.output_padding[1],
                 self.output_padding[2], self.output_padding[3])
            d_out = (x_shape[2] - 1) * stride_d - 2 * (pad_head + pad_tail) + dilation_d * \
                (kernel_d - 1) + self.output_padding[2] + 1
            h_out = (x_shape[3] - 1) * stride_h - 2 * (pad_top + pad_bottom) + dilation_h * \
                (kernel_h - 1) + self.output_padding[3] + 1
            w_out = (x_shape[4] - 1) * stride_w - 2 * (pad_left + pad_right) + dilation_w * \
                (kernel_w - 1) + self.output_padding[4] + 1
            output_shape = (x_shape[0], w_shape[1], d_out, h_out, w_out)
            self.add_prim_attr('input_size', output_shape)
        validator.check("filter's channel", w['shape'][1], "input_size's channel", output_shape[1], Rel.EQ, self.name)
        validator.check("filter's batch", w['shape'][0], "input x's channel", x['shape'][1], Rel.EQ, self.name)
        validator.check("input_size's batch", output_shape[0], "x's batch", x['shape'][0], Rel.EQ, self.name)
        w_shape = w['shape']
        self.add_prim_attr('pads', self.pad_list)
        pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.pad_list
        d_out = (x_shape[2] - 1) * self.stride[2] - (pad_head + pad_tail) + self.dilation[2] * \
            (self.kernel_size[0] - 1) + self.output_padding[2] + 1
        h_out = (x_shape[3] - 1) * self.stride[3] - (pad_top + pad_bottom) + self.dilation[3] * \
            (self.kernel_size[1] - 1) + self.output_padding[3] + 1
        w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \
            (self.kernel_size[2] - 1) + self.output_padding[4] + 1
        output_shape = (x_shape[0], w_shape[1], d_out, h_out, w_out)
        self.add_prim_attr('input_size', output_shape)
        out = {
            'value': None,
            'shape': output_shape,
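
Editor's note: in the "same" branch above, the needed padding is split as evenly as possible between the two sides of each dimension. For example, with hypothetical values x_shape[2] = 10, stride_d = 2, dilation_d = 1, kernel_d = 4 and input_size[2] = 20:

    pad_needed_d = max(0, (10 - 1) * 2 + 1 * (4 - 1) + 1 - 20)   # = 2
    pad_head = pad_needed_d // 2                                  # = 1, same as math.floor(2 / 2)
    pad_tail = pad_needed_d - pad_head                            # = 1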

@@ -39,4 +39,4 @@ if __name__ == '__main__':
    # load the parameter into net
    load_param_into_net(network, param_dict)
    input_data = np.random.uniform(0.0, 1.0, size=[32, 3, 513, 513]).astype(np.float32)
    export(network, Tensor(input_data), file_name=args.model+'-300_11', file_format='AIR')
    export(network, Tensor(input_data), file_name=args.model, file_format='AIR')

@@ -136,7 +136,7 @@ class ASPP(nn.Cell):
        self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics)
        self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics)
        self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics)
        self.aspp_pooling = ASPPPooling(in_channels, out_channels)
        self.aspp_pooling = ASPPPooling(in_channels, out_channels, use_batch_statistics=use_batch_statistics)
        self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1,
                               weight_init='xavier_uniform')
        self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)

@@ -472,11 +472,11 @@ class Conv3DBackpropFilter(nn.Cell):
class Conv3DTranspose(nn.Cell):
    """Conv3DTranspose net definition"""

    def __init__(self, out_channel, kernel_size, input_size, mode, pad_mode, pad, stride, dilation, group, data_format):
    def __init__(self, in_channel, out_channel, kernel_size, mode, pad, stride, dilation, group, data_format):
        super(Conv3DTranspose, self).__init__()
        self.conv = nps.Conv3DTranspose(out_channel=out_channel, kernel_size=kernel_size, input_size=input_size,
                                        mode=mode, pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation,
                                        group=group, data_format=data_format)
        self.conv = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel, kernel_size=kernel_size,
                                        mode=mode, pad=pad, stride=stride, dilation=dilation, group=group,
                                        data_format=data_format)

    def construct(self, x, w):
        ms_out = self.conv(x, w)

@@ -1259,8 +1259,8 @@ test_case_math_ops = [
                        Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))],
        'skip': ['backward']}),
    ('Conv3DTranspose', {
        'block': Conv3DTranspose(out_channel=3, kernel_size=(4, 6, 2), input_size=(0, 0, 0, 0, 0), mode=1,
                                 pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"),
        'block': Conv3DTranspose(in_channel=32, out_channel=3, kernel_size=(4, 6, 2), mode=1,
                                 pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"),
        'desc_inputs': [Tensor(np.random.random((32, 3, 10, 32, 32)).astype(np.float16)),
                        Tensor(np.random.random((3, 3, 4, 6, 2)).astype(np.float16))],
        'skip': ['backward']}),