forked from mindspore-Ecosystem/mindspore
update SSIM loss, add MSSSIM loss feature; add their ut testcases.
This commit is contained in:
parent
03ef509ee9
commit
9b21420b3e
|
@ -21,9 +21,13 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .conv import Conv2d
|
||||
from .container import CellList
|
||||
from .pooling import AvgPool2d
|
||||
from .activation import ReLU
|
||||
from ..cell import Cell
|
||||
|
||||
__all__ = ['ImageGradients', 'SSIM', 'PSNR', 'CentralCrop']
|
||||
__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']
|
||||
|
||||
class ImageGradients(Cell):
|
||||
r"""
|
||||
|
@ -83,21 +87,6 @@ def _convert_img_dtype_to_float32(img, max_val):
|
|||
ret = ret * scale
|
||||
return ret
|
||||
|
||||
|
||||
@constexpr
|
||||
def _gauss_kernel_helper(filter_size):
|
||||
"""gauss kernel helper"""
|
||||
filter_size = F.scalar_cast(filter_size, mstype.int32)
|
||||
coords = ()
|
||||
for i in range(filter_size):
|
||||
i_cast = F.scalar_cast(i, mstype.float32)
|
||||
offset = F.scalar_cast(filter_size-1, mstype.float32)/2.0
|
||||
element = i_cast-offset
|
||||
coords = coords+(element,)
|
||||
g = np.square(coords).astype(np.float32)
|
||||
g = Tensor(g)
|
||||
return filter_size, g
|
||||
|
||||
@constexpr
|
||||
def _check_input_4d(input_shape, param_name, func_name):
|
||||
if len(input_shape) != 4:
|
||||
|
@ -110,9 +99,65 @@ def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
|
|||
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
|
||||
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0):
|
||||
return Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
weight_init=weight, padding=padding, pad_mode="valid")
|
||||
|
||||
def _create_window(size, sigma):
|
||||
x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
|
||||
x_data = np.expand_dims(x_data, axis=-1).astype(np.float32)
|
||||
x_data = np.expand_dims(x_data, axis=-1) ** 2
|
||||
y_data = np.expand_dims(y_data, axis=-1).astype(np.float32)
|
||||
y_data = np.expand_dims(y_data, axis=-1) ** 2
|
||||
sigma = 2 * sigma ** 2
|
||||
g = np.exp(-(x_data + y_data) / sigma)
|
||||
return np.transpose(g / np.sum(g), (2, 3, 0, 1))
|
||||
|
||||
def _split_img(x):
|
||||
_, c, _, _ = F.shape(x)
|
||||
img_split = P.Split(1, c)
|
||||
output = img_split(x)
|
||||
return output, c
|
||||
|
||||
def _compute_per_channel_loss(c1, c2, img1, img2, conv):
|
||||
"""computes ssim index between img1 and img2 per single channel"""
|
||||
dot_img = img1 * img2
|
||||
mu1 = conv(img1)
|
||||
mu2 = conv(img2)
|
||||
mu1_sq = mu1 * mu1
|
||||
mu2_sq = mu2 * mu2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_tmp = conv(img1 * img1)
|
||||
sigma1_sq = sigma1_tmp - mu1_sq
|
||||
sigma2_tmp = conv(img2 * img2)
|
||||
sigma2_sq = sigma2_tmp - mu2_sq
|
||||
sigma12_tmp = conv(dot_img)
|
||||
sigma12 = sigma12_tmp - mu1_mu2
|
||||
a = (2 * mu1_mu2 + c1)
|
||||
b = (mu1_sq + mu2_sq + c1)
|
||||
v1 = 2 * sigma12 + c2
|
||||
v2 = sigma1_sq + sigma2_sq + c2
|
||||
ssim = (a * v1) / (b * v2)
|
||||
cs = v1 / v2
|
||||
return ssim, cs
|
||||
|
||||
def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean):
|
||||
"""computes ssim index between img1 and img2 per color channel"""
|
||||
split_img1, c = _split_img(img1)
|
||||
split_img2, _ = _split_img(img2)
|
||||
multi_ssim = ()
|
||||
multi_cs = ()
|
||||
for i in range(c):
|
||||
ssim_per_channel, cs_per_channel = _compute_per_channel_loss(c1, c2, split_img1[i], split_img2[i], conv)
|
||||
multi_ssim += (ssim_per_channel,)
|
||||
multi_cs += (cs_per_channel,)
|
||||
|
||||
multi_ssim = concat(multi_ssim)
|
||||
multi_cs = concat(multi_cs)
|
||||
|
||||
ssim = mean(multi_ssim, (2, 3))
|
||||
cs = mean(multi_cs, (2, 3))
|
||||
return ssim, cs
|
||||
|
||||
class SSIM(Cell):
|
||||
r"""
|
||||
|
@ -157,67 +202,126 @@ class SSIM(Cell):
|
|||
self.max_val = max_val
|
||||
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
|
||||
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
|
||||
validator.check_value_type('k1', k1, [float], self.cls_name)
|
||||
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
|
||||
validator.check_value_type('k2', k2, [float], self.cls_name)
|
||||
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
|
||||
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
|
||||
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
|
||||
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
|
||||
window = _create_window(filter_size, filter_sigma)
|
||||
self.conv = _conv2d(1, 1, filter_size, Tensor(window))
|
||||
self.conv.weight.requires_grad = False
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, img1, img2):
|
||||
_check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
|
||||
|
||||
kernel = self._fspecial_gauss(self.filter_size, self.filter_sigma)
|
||||
kernel = P.Tile()(kernel, (1, P.Shape()(img1)[1], 1, 1))
|
||||
c1 = (self.k1 * max_val) ** 2
|
||||
c2 = (self.k2 * max_val) ** 2
|
||||
|
||||
mean_ssim = self._calculate_mean_ssim(img1, img2, kernel, max_val, self.k1, self.k2)
|
||||
ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean)
|
||||
loss = self.reduce_mean(ssim_ave_channel, -1)
|
||||
|
||||
return mean_ssim
|
||||
return loss
|
||||
|
||||
def _calculate_mean_ssim(self, x, y, kernel, max_val, k1, k2):
|
||||
"""calculate mean ssim"""
|
||||
c1 = (k1 * max_val) * (k1 * max_val)
|
||||
c2 = (k2 * max_val) * (k2 * max_val)
|
||||
def _downsample(img1, img2, op):
|
||||
a = op(img1)
|
||||
b = op(img2)
|
||||
return a, b
|
||||
|
||||
# SSIM luminance formula
|
||||
# (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
|
||||
mean_x = self.mean(x, kernel)
|
||||
mean_y = self.mean(y, kernel)
|
||||
square_sum = F.square(mean_x)+F.square(mean_y)
|
||||
luminance = (2*mean_x*mean_y+c1)/(square_sum+c1)
|
||||
class MSSSIM(Cell):
|
||||
r"""
|
||||
Returns MS-SSIM index between img1 and img2.
|
||||
|
||||
# SSIM contrast*structure formula (when c3 = c2/2)
|
||||
# (2 * conv_{xy} + c2) / (conv_{xx} + conv_{yy} + c2), equals to
|
||||
# (2 * (mean_{xy} - mean_{x}*mean_{y}) + c2) / (mean_{xx}-mean_{x}**2 + mean_{yy}-mean_{y}**2 + c2)
|
||||
mean_xy = self.mean(x*y, kernel)
|
||||
mean_square_add = self.mean(F.square(x)+F.square(y), kernel)
|
||||
Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity
|
||||
for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_.
|
||||
Signals, Systems and Computers, 2004.
|
||||
|
||||
cs = (2*(mean_xy-mean_x*mean_y)+c2)/(mean_square_add-square_sum+c2)
|
||||
.. math::
|
||||
|
||||
# SSIM formula
|
||||
# luminance * cs
|
||||
ssim = luminance*cs
|
||||
l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
|
||||
c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
|
||||
s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
|
||||
MSSSIM(x,y)&=l^alpha_M*{\prod_{1\leq j\leq M} (c^beta_j*s^gamma_j)}.
|
||||
|
||||
mean_ssim = P.ReduceMean()(ssim, (-3, -2, -1))
|
||||
Args:
|
||||
max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
|
||||
Default: 1.0.
|
||||
power_factors (Union[tuple, list]): Iterable of weights for each of the scales.
|
||||
Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al.
|
||||
filter_size (int): The size of the Gaussian filter. Default: 11.
|
||||
filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
|
||||
k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
|
||||
k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.
|
||||
|
||||
return mean_ssim
|
||||
Inputs:
|
||||
- **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
|
||||
- **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
|
||||
|
||||
def _fspecial_gauss(self, filter_size, filter_sigma):
|
||||
"""get gauss kernel"""
|
||||
filter_size, g = _gauss_kernel_helper(filter_size)
|
||||
Outputs:
|
||||
Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.
|
||||
|
||||
square_sigma_scale = -0.5/(filter_sigma * filter_sigma)
|
||||
g = g*square_sigma_scale
|
||||
g = F.reshape(g, (1, -1))+F.reshape(g, (-1, 1))
|
||||
g = F.reshape(g, (1, -1))
|
||||
g = P.Softmax()(g)
|
||||
ret = F.reshape(g, (1, 1, filter_size, filter_size))
|
||||
return ret
|
||||
Examples:
|
||||
>>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
|
||||
>>> img1 = Tensor(np.random.random((1,3,128,128)))
|
||||
>>> img2 = Tensor(np.random.random((1,3,128,128)))
|
||||
>>> msssim = net(img1, img2)
|
||||
"""
|
||||
def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
|
||||
filter_sigma=1.5, k1=0.01, k2=0.03):
|
||||
super(MSSSIM, self).__init__()
|
||||
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
|
||||
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
|
||||
self.max_val = max_val
|
||||
validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
|
||||
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
|
||||
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
|
||||
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
|
||||
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
|
||||
window = _create_window(filter_size, filter_sigma)
|
||||
self.level = len(power_factors)
|
||||
self.conv = []
|
||||
for i in range(self.level):
|
||||
self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
|
||||
self.conv[i].weight.requires_grad = False
|
||||
self.multi_convs_list = CellList(self.conv)
|
||||
self.weight_tensor = Tensor(power_factors, mstype.float32)
|
||||
self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||
self.relu = ReLU()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.prod = P.ReduceProd()
|
||||
self.pow = P.Pow()
|
||||
self.pack = P.Pack(axis=-1)
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, img1, img2):
|
||||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
|
||||
|
||||
c1 = (self.k1 * max_val) ** 2
|
||||
c2 = (self.k2 * max_val) ** 2
|
||||
|
||||
sim = ()
|
||||
mcs = ()
|
||||
|
||||
for i in range(self.level):
|
||||
sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2,
|
||||
self.multi_convs_list[i], self.concat, self.reduce_mean)
|
||||
mcs += (self.relu(cs),)
|
||||
img1, img2 = _downsample(img1, img2, self.avg_pool)
|
||||
|
||||
mcs = mcs[0:-1:1]
|
||||
mcs_and_ssim = self.pack(mcs + (self.relu(sim),))
|
||||
mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor)
|
||||
ms_ssim = self.prod(mcs_and_ssim, -1)
|
||||
loss = self.reduce_mean(ms_ssim, -1)
|
||||
|
||||
return loss
|
||||
|
||||
class PSNR(Cell):
|
||||
r"""
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test msssim
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import _executor
|
||||
|
||||
_MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
|
||||
|
||||
class MSSSIMNet(nn.Cell):
|
||||
def __init__(self, max_val=1.0, power_factors=_MSSSIM_WEIGHTS, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
|
||||
super(MSSSIMNet, self).__init__()
|
||||
self.net = nn.MSSSIM(max_val, power_factors, filter_size, filter_sigma, k1, k2)
|
||||
|
||||
def construct(self, img1, img2):
|
||||
return self.net(img1, img2)
|
||||
|
||||
|
||||
def test_compile():
|
||||
factors = (0.033, 0.033, 0.033)
|
||||
net = MSSSIMNet(power_factors=factors)
|
||||
img1 = Tensor(np.random.random((8, 3, 128, 128)))
|
||||
img2 = Tensor(np.random.random((8, 3, 128, 128)))
|
||||
_executor.compile(net, img1, img2)
|
||||
|
||||
|
||||
def test_compile_grayscale():
|
||||
max_val = 255
|
||||
factors = (0.033, 0.033, 0.033)
|
||||
net = MSSSIMNet(max_val=max_val, power_factors=factors)
|
||||
img1 = Tensor(np.random.randint(0, 256, (8, 3, 128, 128), np.uint8))
|
||||
img2 = Tensor(np.random.randint(0, 256, (8, 3, 128, 128), np.uint8))
|
||||
_executor.compile(net, img1, img2)
|
||||
|
||||
|
||||
def test_msssim_max_val_negative():
|
||||
max_val = -1
|
||||
with pytest.raises(ValueError):
|
||||
_ = MSSSIMNet(max_val)
|
||||
|
||||
|
||||
def test_msssim_max_val_bool():
|
||||
max_val = True
|
||||
with pytest.raises(TypeError):
|
||||
_ = MSSSIMNet(max_val)
|
||||
|
||||
|
||||
def test_msssim_max_val_zero():
|
||||
max_val = 0
|
||||
with pytest.raises(ValueError):
|
||||
_ = MSSSIMNet(max_val)
|
||||
|
||||
|
||||
def test_msssim_power_factors_set():
|
||||
with pytest.raises(TypeError):
|
||||
_ = MSSSIMNet(power_factors={0.033, 0.033, 0.033})
|
||||
|
||||
|
||||
def test_msssim_filter_size_float():
|
||||
with pytest.raises(TypeError):
|
||||
_ = MSSSIMNet(filter_size=1.1)
|
||||
|
||||
|
||||
def test_msssim_filter_size_zero():
|
||||
with pytest.raises(ValueError):
|
||||
_ = MSSSIMNet(filter_size=0)
|
||||
|
||||
|
||||
def test_msssim_filter_sigma_zero():
|
||||
with pytest.raises(ValueError):
|
||||
_ = MSSSIMNet(filter_sigma=0.0)
|
||||
|
||||
|
||||
def test_msssim_filter_sigma_negative():
|
||||
with pytest.raises(ValueError):
|
||||
_ = MSSSIMNet(filter_sigma=-0.1)
|
||||
|
||||
|
||||
def test_msssim_different_shape():
|
||||
shape_1 = (8, 3, 128, 128)
|
||||
shape_2 = (8, 3, 256, 256)
|
||||
factors = (0.033, 0.033, 0.033)
|
||||
img1 = Tensor(np.random.random(shape_1))
|
||||
img2 = Tensor(np.random.random(shape_2))
|
||||
net = MSSSIMNet(power_factors=factors)
|
||||
with pytest.raises(ValueError):
|
||||
_executor.compile(net, img1, img2)
|
||||
|
||||
|
||||
def test_msssim_different_dtype():
|
||||
dtype_1 = mstype.float32
|
||||
dtype_2 = mstype.float16
|
||||
factors = (0.033, 0.033, 0.033)
|
||||
img1 = Tensor(np.random.random((8, 3, 128, 128)), dtype=dtype_1)
|
||||
img2 = Tensor(np.random.random((8, 3, 128, 128)), dtype=dtype_2)
|
||||
net = MSSSIMNet(power_factors=factors)
|
||||
with pytest.raises(TypeError):
|
||||
_executor.compile(net, img1, img2)
|
||||
|
||||
|
||||
def test_msssim_invalid_5d_input():
|
||||
shape_1 = (8, 3, 128, 128)
|
||||
shape_2 = (8, 3, 256, 256)
|
||||
invalid_shape = (8, 3, 128, 128, 1)
|
||||
factors = (0.033, 0.033, 0.033)
|
||||
img1 = Tensor(np.random.random(shape_1))
|
||||
invalid_img1 = Tensor(np.random.random(invalid_shape))
|
||||
img2 = Tensor(np.random.random(shape_2))
|
||||
invalid_img2 = Tensor(np.random.random(invalid_shape))
|
||||
|
||||
net = MSSSIMNet(power_factors=factors)
|
||||
with pytest.raises(ValueError):
|
||||
_executor.compile(net, invalid_img1, img2)
|
||||
with pytest.raises(ValueError):
|
||||
_executor.compile(net, img1, invalid_img2)
|
||||
with pytest.raises(ValueError):
|
||||
_executor.compile(net, invalid_img1, invalid_img2)
|
|
@ -78,26 +78,6 @@ def test_ssim_filter_sigma_negative():
|
|||
_ = SSIMNet(filter_sigma=-0.1)
|
||||
|
||||
|
||||
def test_ssim_k1_k2_wrong_value():
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k1=1.1)
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k1=1.0)
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k1=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k1=-1.0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k2=1.1)
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k2=1.0)
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k2=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
_ = SSIMNet(k2=-1.0)
|
||||
|
||||
|
||||
def test_ssim_different_shape():
|
||||
shape_1 = (8, 3, 16, 16)
|
||||
shape_2 = (8, 3, 8, 8)
|
||||
|
|
Loading…
Reference in New Issue