forked from mindspore-Ecosystem/mindspore
Fix MSSSIM precision when dtype is uint32.
This commit is contained in:
parent
4f754daccf
commit
80d2214361
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""image"""
|
||||
import numbers
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -93,6 +94,16 @@ def _convert_img_dtype_to_float32(img, max_val):
|
|||
ret = ret * scale
|
||||
return ret
|
||||
|
||||
@constexpr
def _get_dtype_max(dtype):
    """
    Get the maximum value used to normalize data of the given dtype.

    Args:
        dtype: A type that `mstype.dtype_to_nptype` can map to a NumPy type.

    Returns:
        The largest value representable by the mapped NumPy type, as
        `np.float64`, when that type is registered as `numbers.Integral`;
        otherwise the Python float 1.0.
    """
    numpy_type = mstype.dtype_to_nptype(dtype)
    # Non-integral types fall back to a maximum of 1.0.
    if not issubclass(numpy_type, numbers.Integral):
        return 1.0
    # Integral types use the upper bound reported by NumPy for that type.
    return np.float64(np.iinfo(numpy_type).max)
|
||||
|
||||
@constexpr
|
||||
def _check_input_4d(input_shape, param_name, func_name):
|
||||
if len(input_shape) != 4:
|
||||
|
@ -224,9 +235,11 @@ class SSIM(Cell):
|
|||
_check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
|
||||
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
||||
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
||||
max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)
|
||||
|
||||
c1 = (self.k1 * max_val) ** 2
|
||||
c2 = (self.k2 * max_val) ** 2
|
||||
|
@ -309,10 +322,13 @@ class MSSSIM(Cell):
|
|||
def construct(self, img1, img2):
|
||||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
||||
_check_input_dtype(F.dtype(img1), 'img1', mstype.number_type, self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
|
||||
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
||||
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
||||
max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)
|
||||
|
||||
c1 = (self.k1 * max_val) ** 2
|
||||
c2 = (self.k2 * max_val) ** 2
|
||||
|
@ -375,9 +391,11 @@ class PSNR(Cell):
|
|||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
|
||||
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
||||
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
||||
max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
|
||||
img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)
|
||||
|
||||
mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
|
||||
psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)
|
||||
|
|
Loading…
Reference in New Issue