ssim impl code

This commit is contained in:
zhaozhenlong 2020-04-09 16:58:18 +08:00
parent 9bda080bb5
commit 6a2cf4b6e6
4 changed files with 202 additions and 48 deletions

View File

@ -22,9 +22,10 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients, Pad
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad
from .embedding import Embedding
from .pooling import AvgPool2d, MaxPool2d
from .image import ImageGradients, SSIM
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
@ -32,7 +33,8 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'SequentialCell', 'CellList',
'Conv2d', 'Conv2dTranspose',
'LSTM',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'ImageGradients',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
'Embedding',
'AvgPool2d', 'MaxPool2d', 'Pad',
'ImageGradients', 'SSIM',
]

View File

@ -372,51 +372,6 @@ class OneHot(Cell):
return self.onehot(indices, self.depth, self.on_value, self.off_value)
class ImageGradients(Cell):
r"""
Returns two tensors, the first is along the height dimension and the second is along the width dimension.
Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
respectively.
.. math::
dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
0, &if\ i==h-1\end{cases}
dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
0, &if\ i==w-1\end{cases}
Inputs:
- **images** (Tensor) - The input image data, with format 'NCHW'.
Outputs:
- **dy** (Tensor) - vertical image gradients, the same type and shape as input.
- **dx** (Tensor) - horizontal image gradients, the same type and shape as input.
Examples:
>>> net = nn.ImageGradients()
>>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
>>> net(image)
[[[[2,2]
[0,0]]]]
[[[[1,0]
[1,0]]]]
"""
def __init__(self):
super(ImageGradients, self).__init__()
def construct(self, images):
batch_size, depth, height, width = P.Shape()(images)
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dy = P.Concat(2)((dy, dy_last))
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
dx = P.Concat(3)((dx, dx_last))
return dy, dx
class Pad(Cell):
"""
Pads the input tensor according to the paddings and mode.

197
mindspore/nn/layer/image.py Normal file
View File

@ -0,0 +1,197 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""image"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel
from ..cell import Cell
class ImageGradients(Cell):
r"""
Returns two tensors, the first is along the height dimension and the second is along the width dimension.
Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
respectively.
.. math::
dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
0, &if\ i==h-1\end{cases}
dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
0, &if\ i==w-1\end{cases}
Inputs:
- **images** (Tensor) - The input image data, with format 'NCHW'.
Outputs:
- **dy** (Tensor) - vertical image gradients, the same type and shape as input.
- **dx** (Tensor) - horizontal image gradients, the same type and shape as input.
Examples:
>>> net = nn.ImageGradients()
>>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
>>> net(image)
[[[[2,2]
[0,0]]]]
[[[[1,0]
[1,0]]]]
"""
def __init__(self):
super(ImageGradients, self).__init__()
def construct(self, images):
batch_size, depth, height, width = P.Shape()(images)
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dy = P.Concat(2)((dy, dy_last))
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
dx = P.Concat(3)((dx, dx_last))
return dy, dx
@constexpr
def _gauss_kernel_helper(filter_size):
"""gauss kernel helper"""
filter_size = F.scalar_cast(filter_size, mstype.int32)
coords = ()
for i in range(filter_size):
i_cast = F.scalar_cast(i, mstype.float32)
offset = F.scalar_cast(filter_size-1, mstype.float32)/2.0
element = i_cast-offset
coords = coords+(element,)
g = np.square(coords).astype(np.float32)
g = Tensor(g)
return filter_size, g
class SSIM(Cell):
r"""
Returns SSIM index between img1 and img2.
Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). `Image quality
assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_.
IEEE transactions on image processing.
.. math::
l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.
Args:
max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
Default: 1.0.
filter_size (int): The size of the Gaussian filter. Default: 11.
filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.
Inputs:
- **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
- **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
Outputs:
Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.
Examples:
>>> net = nn.SSIM()
>>> img1 = Tensor(np.random.random((1,3,16,16)))
>>> img2 = Tensor(np.random.random((1,3,16,16)))
>>> ssim = net(img1, img2)
"""
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
super(SSIM, self).__init__()
validator.check_type('max_val', max_val, [int, float])
validator.check('max_val', max_val, '', 0.0, Rel.GT)
self.max_val = max_val
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE)
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma)
validator.check_type('k1', k1, [float])
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER)
validator.check_type('k2', k2, [float])
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER)
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2):
max_val = self._convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = self._convert_img_dtype_to_float32(img1, self.max_val)
img2 = self._convert_img_dtype_to_float32(img2, self.max_val)
kernel = self._fspecial_gauss(self.filter_size, self.filter_sigma)
kernel = P.Tile()(kernel, (1, P.Shape()(img1)[1], 1, 1))
mean_ssim = self._calculate_mean_ssim(img1, img2, kernel, max_val, self.k1, self.k2)
return mean_ssim
def _convert_img_dtype_to_float32(self, img, max_val):
"""convert img dtype to float32"""
# Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# We will scale img pixel value if max_val > 1. and just cast otherwise.
ret = P.Cast()(img, mstype.float32)
max_val = F.scalar_cast(max_val, mstype.float32)
if max_val > 1.:
scale = 1./max_val
ret = ret*scale
return ret
def _calculate_mean_ssim(self, x, y, kernel, max_val, k1, k2):
"""calculate mean ssim"""
c1 = (k1*max_val)*(k1*max_val)
c2 = (k2*max_val)*(k2*max_val)
# SSIM luminance formula
# (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
mean_x = self.mean(x, kernel)
mean_y = self.mean(y, kernel)
square_sum = F.square(mean_x)+F.square(mean_y)
luminance = (2*mean_x*mean_y+c1)/(square_sum+c1)
# SSIM contrast*structure formula (when c3 = c2/2)
# (2 * conv_{xy} + c2) / (conv_{xx} + conv_{yy} + c2), equals to
# (2 * (mean_{xy} - mean_{x}*mean_{y}) + c2) / (mean_{xx}-mean_{x}**2 + mean_{yy}-mean_{y}**2 + c2)
mean_xy = self.mean(x*y, kernel)
mean_square_add = self.mean(F.square(x)+F.square(y), kernel)
cs = (2*(mean_xy-mean_x*mean_y)+c2)/(mean_square_add-square_sum+c2)
# SSIM formula
# luminance * cs
ssim = luminance*cs
mean_ssim = P.ReduceMean()(ssim, (-3, -2, -1))
return mean_ssim
def _fspecial_gauss(self, filter_size, filter_sigma):
"""get gauss kernel"""
filter_size, g = _gauss_kernel_helper(filter_size)
square_sigma_scale = -0.5/(filter_sigma * filter_sigma)
g = g*square_sigma_scale
g = F.reshape(g, (1, -1))+F.reshape(g, (-1, 1))
g = F.reshape(g, (1, -1))
g = P.Softmax()(g)
ret = F.reshape(g, (1, 1, filter_size, filter_size))
return ret

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test loss """
""" test image gradients """
import numpy as np
import mindspore.nn as nn
import mindspore.context as context