add safe_normalize

This commit is contained in:
zhujingxuan 2021-11-03 16:28:29 +08:00
parent cd44957b2d
commit f128914f33
2 changed files with 70 additions and 1 deletions

View File

@ -14,7 +14,9 @@
# ============================================================================
"""internal utility functions"""
import numpy as onp
from .. import nn, ops
from ..numpy import where, isnan, zeros_like
from ..ops import functional as F
from ..common import Tensor
from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_type_error
@ -58,6 +60,30 @@ def _to_scalar(arr):
raise ValueError("{} are not supported.".format(type(arr)))
class _SafeNormalize(nn.Cell):
"""Normalize method that cast very small results to zero."""
def __init__(self):
"""Initialize LineSearch."""
super(_SafeNormalize, self).__init__()
self.eps = ops.Eps()
def construct(self, x, threshold=None):
x_sum2 = F.reduce_sum(F.pows(x, 2.0))
norm = F.pows(x_sum2, 1./2.0)
if threshold is None:
if x.dtype in mstype.float_type:
# pick the first element of x to get the eps
threshold = self.eps(x[(0,) * x.ndim])
else:
threshold = 0
normalized_x = where(norm > threshold, x / norm, zeros_like(x))
normalized_x = where(isnan(normalized_x), 0, normalized_x)
return normalized_x, norm
_safe_normalize = _SafeNormalize()
_FLOAT_ONE = _to_tensor(1.0)
_FLOAT_ZERO = _to_tensor(0.0)
_INT_ZERO = _to_tensor(0)

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.utils"""
import pytest
import numpy as onp
from mindspore import context, Tensor
from mindspore.scipy.utils import _safe_normalize
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('shape', [(10,), (10, 1)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_safe_normalize(mode, shape, dtype):
"""
Feature: ALL TO ALL
Description: test cases for _safe_normalize
Expectation: the result match scipy
"""
context.set_context(mode=mode)
x = onp.random.random(shape).astype(dtype)
normalized_x, x_norm = _safe_normalize(Tensor(x))
normalized_x = normalized_x.asnumpy()
x_norm = x_norm.asnumpy()
assert onp.allclose(onp.sum(normalized_x ** 2), 1)
assert onp.allclose(x / x_norm, normalized_x)