forked from mindspore-Ecosystem/mindspore
fix image gradient height 1
This commit is contained in:
parent
e4c992aabb
commit
b790cc1653
|
@ -66,13 +66,19 @@ class ImageGradients(Cell):
|
|||
check = _check_input_4d(F.shape(images), "images", self.cls_name)
|
||||
images = F.depend(images, check)
|
||||
batch_size, depth, height, width = P.Shape()(images)
|
||||
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
|
||||
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
|
||||
dy = P.Concat(2)((dy, dy_last))
|
||||
if height == 1:
|
||||
dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
|
||||
else:
|
||||
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
|
||||
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
|
||||
dy = P.Concat(2)((dy, dy_last))
|
||||
|
||||
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
|
||||
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
|
||||
dx = P.Concat(3)((dx, dx_last))
|
||||
if width == 1:
|
||||
dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
|
||||
else:
|
||||
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
|
||||
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
|
||||
dx = P.Concat(3)((dx, dx_last))
|
||||
return dy, dx
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue