fix pad input dims less than padding

This commit is contained in:
fengyihang 2022-12-12 11:13:33 +08:00
parent e01fafa85e
commit 381464dd35
2 changed files with 8 additions and 0 deletions

View File

@ -28,6 +28,8 @@ mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_dee
mindspore/mindspore/ccsrc/pipeline/jit/resource.cc:mindspore::pipeline::GetMethodMap
mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
mindspore/mindspore/python/mindspore/common/parameter.py:set_data

View File

@ -2490,6 +2490,7 @@ def pad(input_x, padding, mode='constant', value=None):
if not isinstance(padding, Tensor):
_check_pad_inputs(padding)
padding = Tensor(padding)
is_expand = False
if mode == "constant":
value = 0 if value is None else value
if isinstance(value, (float, int)):
@ -2499,7 +2500,12 @@ def pad(input_x, padding, mode='constant', value=None):
raise ValueError(f"For 'pad', the padding mode '{mode}' can not set value, but got value {value}.")
if mode == "replicate":
mode = "edge"
if padding.shape[0] == input_x.ndim + 1:
input_x = input_x.expand_dims(0)
is_expand = True
out = PadV3(mode=mode, paddings_contiguous=True)(input_x, padding, value)
if is_expand:
out = out.squeeze(0)
return out