support more input

This commit is contained in:
fengyihang 2022-10-18 11:25:57 +08:00
parent 72f9ce3939
commit 4906616b89
3 changed files with 30 additions and 13 deletions

View File

@ -26,10 +26,10 @@ mindspore.nn.AvgPool3d
- **divisor_override** (int) - 如果被指定为非0参数该参数将会在平均计算中被用作除数否则将会使用 `kernel_size` 作为除数默认值None。
输入:
- **x** (Tensor) - shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为float16或者float32。
- **x** (Tensor) - shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})` 或者 :math:`(C, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为float16或者float32。
输出:
shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 的Tensor。数据类型与 `x` 一致。
shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 或者 :math:`(C, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型与 `x` 一致。
异常:
- **TypeError** - `kernel_size` `stride``padding` 既不是整数也不是元组。

View File

@ -23,13 +23,13 @@ mindspore.nn.MaxPool3d
- **ceil_mode** (bool) - 若为True使用ceil来计算输出shape。若为False使用floor来计算输出shape。默认值False。
输入:
- **x** (Tensor) - shape为 :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为int8、 int16、 int32、 int64、 uint8、 uint16、 uint32、 uint64、 float16、 float32 或者 float64。
- **x** (Tensor) - shape为 :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` 或者 :math:`(C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为int8、 int16、 int32、 int64、 uint8、 uint16、 uint32、 uint64、 float16、 float32 或者 float64。
输出:
如果 `return_indices` 为False则是shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 的Tensor。数据类型与 `x` 一致。
如果 `return_indices` 为False则是shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 或者 :math:`(C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型与 `x` 一致。
如果 `return_indices` 为True则是一个包含了两个Tensor的Tuple表示maxpool的计算结果以及生成max值的位置。
- **output** (Tensor) - 最大池化结果shape为 :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})`的Tensor。数据类型与 `x` 一致。
- **output** (Tensor) - 最大池化结果shape为 :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})` 或者 :math:`(C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型与 `x` 一致。
- **argmax** (Tensor) - 最大值对应的索引。数据类型为int64。
异常:

View File

@ -108,18 +108,19 @@ class MaxPool3d(Cell):
ceil_mode (bool): Whether to use ceil or floor to calculate output shape. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` with data type of int8,
int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32 or float64.
- **x** (Tensor) - Tensor of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` or
:math:`(C_{in}, D_{in}, H_{in}, W_{in})` with data type of int8, int16, int32,
int64, uint8, uint16, uint32, uint64, float16, float32 or float64.
Outputs:
If `return_indices` is False, output is a Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`,
with the same data type as `x`.
If `return_indices` is False, output is a Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`, or
:math:`(C_{out}, D_{out}, H_{out}, W_{out})`. It has the same data type as `x`.
If `return_indices` is True, output is a Tuple of 2 Tensors, representing the maxpool result and where
the max values are generated.
- **output** (Tensor) - Maxpooling result, with shape :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})`.
It has the same data type as `x`.
- **output** (Tensor) - Maxpooling result, with shape :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})` or
:math:`(C_{out}, D_{out}, H_{out}, W_{out})`. It has the same data type as `x`.
- **argmax** (Tensor) - Index corresponding to the maximum value. Data type is int64.
Raises:
@ -175,9 +176,15 @@ class MaxPool3d(Cell):
super(MaxPool3d, self).__init__()
self.return_indices = return_indices
self.max_pool = MaxPool3DWithArgmax(kernel_size, stride, padding, dilation, ceil_mode)
self.expand_dims = P.ExpandDims()
def construct(self, x):
_shape = x.shape
if len(x.shape) == 4:
x = self.expand_dims(x, 0)
output_tensor, argmax = self.max_pool(x)
output_tensor = output_tensor.reshape(_shape)
argmax = argmax.reshape(_shape)
if self.return_indices:
return output_tensor, argmax
return output_tensor
@ -375,11 +382,13 @@ class AvgPool3d(Cell):
otherwise kernel_size will be used. Default: None.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
- **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})` or
:math:`(C, D_{in}, H_{in}, W_{in})`.
Currently support float16 and float32 data type.
Outputs:
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`, with the same data type with `x`.
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})` or
:math:`(C, D_{in}, H_{in}, W_{in})`, with the same data type with `x`.
Raises:
TypeError: If `kernel_size`, `stride` or `padding` is neither an int nor a tuple.
@ -415,9 +424,17 @@ class AvgPool3d(Cell):
divisor_override = 0
self.avg_pool = P.AvgPool3D(kernel_size, stride, "pad", padding, ceil_mode, count_include_pad,
divisor_override)
self.squeeze = P.Squeeze(0)
self.expand_dims = P.ExpandDims()
def construct(self, x):
_is_squeeze = False
if len(x.shape) == 4:
x = self.expand_dims(x, 0)
_is_squeeze = True
out = self.avg_pool(x)
if _is_squeeze:
out = self.squeeze(out)
return out