add api image gradients

This commit is contained in:
zhaozhenlong 2020-04-03 15:37:42 +08:00
parent 475e858474
commit f9d180d413
4 changed files with 158 additions and 2 deletions

View File

@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients
from .embedding import Embedding
from .pooling import AvgPool2d, MaxPool2d
@ -31,7 +31,7 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'SequentialCell', 'CellList',
'Conv2d', 'Conv2dTranspose',
'LSTM',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'ImageGradients',
'Embedding',
'AvgPool2d', 'MaxPool2d',
]

View File

@ -370,3 +370,48 @@ class OneHot(Cell):
def construct(self, indices):
return self.onehot(indices, self.depth, self.on_value, self.off_value)
class ImageGradients(Cell):
r"""
Returns two tensors, the first is along the height dimension and the second is along the width dimension.
Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
respectively.
.. math::
dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
0, &if\ i==h-1\end{cases}
dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
0, &if\ i==w-1\end{cases}
Inputs:
- **images** (Tensor) - The input image data, with format 'NCHW'.
Outputs:
- **dy** (Tensor) - vertical image gradients, the same type and shape as input.
- **dx** (Tensor) - horizontal image gradients, the same type and shape as input.
Examples:
>>> net = nn.ImageGradients()
>>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
>>> net(image)
[[[[2,2]
[0,0]]]]
[[[[1,0]
[1,0]]]]
"""
def __init__(self):
super(ImageGradients, self).__init__()
def construct(self, images):
batch_size, depth, height, width = P.Shape()(images)
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dy = P.Concat(2)((dy, dy_last))
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
dx = P.Concat(3)((dx, dx_last))
return dy, dx

View File

@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.common.api import ms_function
context.set_context(device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.image_gradients = nn.ImageGradients()
@ms_function
def construct(self, x):
return self.image_gradients(x)
def test_image_gradients():
image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
expected_dy = np.array([[[[2,2],[0,0]]]]).astype(np.int32)
expected_dx = np.array([[[[1,0],[1,0]]]]).astype(np.int32)
net = Net()
dy, dx = net(image)
assert np.any(dx.asnumpy()-expected_dx) == False
assert np.any(dy.asnumpy()-expected_dy) == False
def test_image_gradients_multi_channel_depth():
# 4 x 2 x 2 x 2
dtype = mstype.int32
image = Tensor(np.array([[[[1,2],[3,4]], [[5,6],[7,8]]],
[[[3,5],[7,9]], [[11,13],[15,17]]],
[[[5,10],[15,20]], [[25,30],[35,40]]],
[[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype)
expected_dy = Tensor(np.array([[[[2,2],[0,0]], [[2,2],[0,0]]],
[[[4,4],[0,0]], [[4,4],[0,0]]],
[[[10,10],[0,0]], [[10,10],[0,0]]],
[[[20,20],[0,0]], [[20,20],[0,0]]]]), dtype=dtype)
expected_dx = Tensor(np.array([[[[1,0],[1,0]], [[1,0],[1,0]]],
[[[2,0],[2,0]], [[2,0],[2,0]]],
[[[5,0],[5,0]], [[5,0],[5,0]]],
[[[10,0],[10,0]], [[10,0],[10,0]]]]), dtype=dtype)
net = Net()
dy, dx = net(image)
assert np.any(dx.asnumpy()-expected_dx.asnumpy()) == False
assert np.any(dy.asnumpy()-expected_dy.asnumpy()) == False

View File

@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test loss """
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.common.api import ms_function
context.set_context(device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.image_gradients = nn.ImageGradients()
@ms_function
def construct(self, x):
return self.image_gradients(x)
def test_compile():
# input shape 1 x 1 x 2 x 2
image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
net = Net()
_executor.compile(net, image)
def test_compile_multi_channel():
# input shape 4 x 2 x 2 x 2
dtype = mstype.int32
image = Tensor(np.array([[[[1,2],[3,4]], [[5,6],[7,8]]],
[[[3,5],[7,9]], [[11,13],[15,17]]],
[[[5,10],[15,20]], [[25,30],[35,40]]],
[[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype)
net = Net()
_executor.compile(net, image)