!46684 fix pad input dims less than padding in non constant mode

Merge pull request !46684 from 冯一航/fix_pad_inputdim_less_than_padding
This commit is contained in:
i-robot 2022-12-13 09:43:17 +00:00 committed by Gitee
commit a2ca59ea50
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 8 additions and 0 deletions

View File

@ -28,6 +28,8 @@ mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_dee
mindspore/mindspore/ccsrc/pipeline/jit/resource.cc:mindspore::pipeline::GetMethodMap mindspore/mindspore/ccsrc/pipeline/jit/resource.cc:mindspore::pipeline::GetMethodMap
mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
mindspore/mindspore/python/mindspore/common/tensor.py:__init__ mindspore/mindspore/python/mindspore/common/tensor.py:__init__
mindspore/mindspore/python/mindspore/common/parameter.py:set_data mindspore/mindspore/python/mindspore/common/parameter.py:set_data

View File

@ -2490,6 +2490,7 @@ def pad(input_x, padding, mode='constant', value=None):
if not isinstance(padding, Tensor): if not isinstance(padding, Tensor):
_check_pad_inputs(padding) _check_pad_inputs(padding)
padding = Tensor(padding) padding = Tensor(padding)
is_expand = False
if mode == "constant": if mode == "constant":
value = 0 if value is None else value value = 0 if value is None else value
if isinstance(value, (float, int)): if isinstance(value, (float, int)):
@ -2499,7 +2500,12 @@ def pad(input_x, padding, mode='constant', value=None):
raise ValueError(f"For 'pad', the padding mode '{mode}' can not set value, but got value {value}.") raise ValueError(f"For 'pad', the padding mode '{mode}' can not set value, but got value {value}.")
if mode == "replicate": if mode == "replicate":
mode = "edge" mode = "edge"
if padding.shape[0] == input_x.ndim + 1:
input_x = input_x.expand_dims(0)
is_expand = True
out = PadV3(mode=mode, paddings_contiguous=True)(input_x, padding, value) out = PadV3(mode=mode, paddings_contiguous=True)(input_x, padding, value)
if is_expand:
out = out.squeeze(0)
return out return out