forked from mindspore-Ecosystem/mindspore
!13123 fix the output shape of the operator maxPoolGradGrad
From: @david-he91 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
89969f3502
|
@ -976,12 +976,12 @@ class MaxPoolGradGrad(_PoolGrad):
|
|||
super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode)
|
||||
|
||||
def infer_shape(self, x1_shape, x2_shape, grad_shape):
|
||||
return x1_shape
|
||||
return x2_shape
|
||||
|
||||
def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
|
||||
args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
|
||||
return x1_dtype
|
||||
return x2_dtype
|
||||
|
||||
|
||||
def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):
|
||||
|
|
|
@ -3376,13 +3376,13 @@ class TensorScatterUpdate(PrimitiveWithInfer):
|
|||
`indices`, with values from `update`. This operation is almost equivalent to using
|
||||
ScatterNd, except that the updates are applied on `input_x` instead of a zero tensor.
|
||||
|
||||
`indices` must have rank atleast 2, the last axis is the depth of each index
|
||||
`indices` must have rank at least 2, the last axis is the depth of each index
|
||||
vectors. For each index vector, there must be a corresponding value in `update`. If
|
||||
the depth of each index tensor matches the rank of `input_x`, then each index
|
||||
vector corresponds to a scalar in `input_x` and each update updates a scalar. If
|
||||
the depth of each index tensor is less than the rnak of `input_x`, then each index
|
||||
vector corresponds to a slice in `input_x`, and each update updates a slice.
|
||||
|
||||
|
||||
The order in which updates are applied is nondeterministic, meaning that if there
|
||||
are multiple index vectors in `indices` that correspond to the same position, the
|
||||
value of that position in the output will be nondeterministic.
|
||||
|
@ -3390,7 +3390,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
|
|||
Inputs:
|
||||
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
|
||||
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
||||
The rank must be atleast 2.
|
||||
The rank must be at least 2.
|
||||
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
|
||||
|
||||
|
@ -3520,7 +3520,7 @@ class ScatterNdUpdate(_ScatterNdOp):
|
|||
- **indices** (Tensor) - The index of input tensor, with int32 data type.
|
||||
The rank of indices must be at least 2 and `indices_shape[-1] <= len(shape)`.
|
||||
- **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
|
||||
the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
|
||||
The shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
|
Loading…
Reference in New Issue