support more input
This commit is contained in:
parent
72f9ce3939
commit
4906616b89
|
@ -26,10 +26,10 @@ mindspore.nn.AvgPool3d
|
|||
- **divisor_override** (int) - 如果被指定为非0参数,该参数将会在平均计算中被用作除数,否则将会使用 `kernel_size` 作为除数,默认值:None。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为float16或者float32。
|
||||
- **x** (Tensor) - shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})` 或者 :math:`(C, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为float16或者float32。
|
||||
|
||||
输出:
|
||||
shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 的Tensor。数据类型与 `x` 一致。
|
||||
shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 或者 :math:`(C, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型与 `x` 一致。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `kernel_size` ,`stride` 或 `padding` 既不是整数也不是元组。
|
||||
|
|
|
@ -23,13 +23,13 @@ mindspore.nn.MaxPool3d
|
|||
- **ceil_mode** (bool) - 若为True,使用ceil来计算输出shape。若为False,使用floor来计算输出shape。默认值:False。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - shape为 :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为int8、 int16、 int32、 int64、 uint8、 uint16、 uint32、 uint64、 float16、 float32 或者 float64。
|
||||
- **x** (Tensor) - shape为 :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` 或者 :math:`(C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型必须为int8、 int16、 int32、 int64、 uint8、 uint16、 uint32、 uint64、 float16、 float32 或者 float64。
|
||||
|
||||
输出:
|
||||
如果 `return_indices` 为False,则是shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 的Tensor。数据类型与 `x` 一致。
|
||||
如果 `return_indices` 为False,则是shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})` 或者 :math:`(C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型与 `x` 一致。
|
||||
如果 `return_indices` 为True,则是一个包含了两个Tensor的Tuple,表示maxpool的计算结果以及生成max值的位置。
|
||||
|
||||
- **output** (Tensor) - 最大池化结果,shape为 :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})`的Tensor。数据类型与 `x` 一致。
|
||||
- **output** (Tensor) - 最大池化结果,shape为 :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})` 或者 :math:`(C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。数据类型与 `x` 一致。
|
||||
- **argmax** (Tensor) - 最大值对应的索引。数据类型为int64。
|
||||
|
||||
异常:
|
||||
|
|
|
@ -108,18 +108,19 @@ class MaxPool3d(Cell):
|
|||
ceil_mode (bool): Whether to use ceil or floor to calculate output shape. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` with data type of int8,
|
||||
int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32 or float64.
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})` or
|
||||
:math:`(C_{in}, D_{in}, H_{in}, W_{in})` with data type of int8, int16, int32,
|
||||
int64, uint8, uint16, uint32, uint64, float16, float32 or float64.
|
||||
|
||||
Outputs:
|
||||
If `return_indices` is False, output is a Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`,
|
||||
with the same data type as `x`.
|
||||
If `return_indices` is False, output is a Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`, or
|
||||
:math:`(C_{out}, D_{out}, H_{out}, W_{out})`. It has the same data type as `x`.
|
||||
|
||||
If `return_indices` is True, output is a Tuple of 2 Tensors, representing the maxpool result and where
|
||||
the max values are generated.
|
||||
|
||||
- **output** (Tensor) - Maxpooling result, with shape :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})`.
|
||||
It has the same data type as `x`.
|
||||
- **output** (Tensor) - Maxpooling result, with shape :math:`(N_{out}, C_{out}, D_{out}, H_{out}, W_{out})` or
|
||||
:math:`(C_{out}, D_{out}, H_{out}, W_{out})`. It has the same data type as `x`.
|
||||
- **argmax** (Tensor) - Index corresponding to the maximum value. Data type is int64.
|
||||
|
||||
Raises:
|
||||
|
@ -175,9 +176,15 @@ class MaxPool3d(Cell):
|
|||
super(MaxPool3d, self).__init__()
|
||||
self.return_indices = return_indices
|
||||
self.max_pool = MaxPool3DWithArgmax(kernel_size, stride, padding, dilation, ceil_mode)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
|
||||
def construct(self, x):
|
||||
_shape = x.shape
|
||||
if len(x.shape) == 4:
|
||||
x = self.expand_dims(x, 0)
|
||||
output_tensor, argmax = self.max_pool(x)
|
||||
output_tensor = output_tensor.reshape(_shape)
|
||||
argmax = argmax.reshape(_shape)
|
||||
if self.return_indices:
|
||||
return output_tensor, argmax
|
||||
return output_tensor
|
||||
|
@ -375,11 +382,13 @@ class AvgPool3d(Cell):
|
|||
otherwise kernel_size will be used. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})` or
|
||||
:math:`(C, D_{in}, H_{in}, W_{in})`.
|
||||
Currently support float16 and float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`, with the same data type with `x`.
|
||||
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})` or
|
||||
:math:`(C, D_{in}, H_{in}, W_{in})`, with the same data type with `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `kernel_size`, `stride` or `padding` is neither an int nor a tuple.
|
||||
|
@ -415,9 +424,17 @@ class AvgPool3d(Cell):
|
|||
divisor_override = 0
|
||||
self.avg_pool = P.AvgPool3D(kernel_size, stride, "pad", padding, ceil_mode, count_include_pad,
|
||||
divisor_override)
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
|
||||
def construct(self, x):
|
||||
_is_squeeze = False
|
||||
if len(x.shape) == 4:
|
||||
x = self.expand_dims(x, 0)
|
||||
_is_squeeze = True
|
||||
out = self.avg_pool(x)
|
||||
if _is_squeeze:
|
||||
out = self.squeeze(out)
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue