!6908 Added raise_not_implement_error functionality in distribution

Merge pull request !6908 from XunDeng/not_impl_error
This commit is contained in:
mindspore-ci-bot 2020-09-27 14:54:59 +08:00 committed by Gitee
commit 5171c90048
3 changed files with 155 additions and 5 deletions

View File

@ -218,6 +218,11 @@ def raise_not_impl_error(name):
raise ValueError(
f"{name} function should be implemented for non-linear transformation")
@constexpr
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
raise NotImplementedError(
f"{func_name} is not implemented for {obj} distribution.")
@constexpr
def check_distribution_name(name, expected_name):

View File

@ -19,7 +19,8 @@ from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import get_seed
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
raise_not_implemented_util
from ._utils.utils import CheckTuple, CheckTensor
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
@ -245,6 +246,8 @@ class Distribution(Cell):
self._call_prob = self._prob
elif hasattr(self, '_log_prob'):
self._call_prob = self._calc_prob_from_log_prob
else:
self._call_prob = self._raise_not_implemented_error('prob')
def _set_sd(self):
"""
@ -254,6 +257,8 @@ class Distribution(Cell):
self._call_sd = self._sd
elif hasattr(self, '_var'):
self._call_sd = self._calc_sd_from_var
else:
self._call_sd = self._raise_not_implemented_error('sd')
def _set_var(self):
"""
@ -263,6 +268,8 @@ class Distribution(Cell):
self._call_var = self._var
elif hasattr(self, '_sd'):
self._call_var = self._calc_var_from_sd
else:
self._call_var = self._raise_not_implemented_error('var')
def _set_log_prob(self):
"""
@ -272,6 +279,8 @@ class Distribution(Cell):
self._call_log_prob = self._log_prob
elif hasattr(self, '_prob'):
self._call_log_prob = self._calc_log_prob_from_prob
else:
self._call_log_prob = self._raise_not_implemented_error('log_prob')
def _set_cdf(self):
"""
@ -286,13 +295,18 @@ class Distribution(Cell):
self._call_cdf = self._calc_cdf_from_survival
elif hasattr(self, '_log_survival'):
self._call_cdf = self._calc_cdf_from_log_survival
else:
self._call_cdf = self._raise_not_implemented_error('cdf')
def _set_survival(self):
"""
Set survival function based on the availability of _survival function and `_log_survival`
and `_call_cdf`.
"""
if hasattr(self, '_survival_function'):
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
self._call_survival = self._raise_not_implemented_error('survival_function')
elif hasattr(self, '_survival_function'):
self._call_survival = self._survival_function
elif hasattr(self, '_log_survival'):
self._call_survival = self._calc_survival_from_log_survival
@ -303,7 +317,10 @@ class Distribution(Cell):
"""
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
"""
if hasattr(self, '_log_cdf'):
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
elif hasattr(self, '_log_cdf'):
self._call_log_cdf = self._log_cdf
elif hasattr(self, '_call_cdf'):
self._call_log_cdf = self._calc_log_cdf_from_call_cdf
@ -312,7 +329,10 @@ class Distribution(Cell):
"""
Set log survival based on the availability of `_log_survival` and `_call_survival`.
"""
if hasattr(self, '_log_survival'):
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
self._call_log_survival = self._raise_not_implemented_error('log_cdf')
elif hasattr(self, '_log_survival'):
self._call_log_survival = self._log_survival
elif hasattr(self, '_call_survival'):
self._call_log_survival = self._calc_log_survival_from_call_survival
@ -323,6 +343,14 @@ class Distribution(Cell):
"""
if hasattr(self, '_cross_entropy'):
self._call_cross_entropy = self._cross_entropy
else:
self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy')
def _raise_not_implemented_error(self, func_name):
name = self.name
def raise_error(*args, **kwargs):
return raise_not_implemented_util(func_name, name, *args, **kwargs)
return raise_error
def log_prob(self, value, *args, **kwargs):
"""
@ -495,6 +523,9 @@ class Distribution(Cell):
"""
return self.log_base(self._call_survival(value, *args, **kwargs))
def _kl_loss(self, *args, **kwargs):
return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)
def kl_loss(self, dist, *args, **kwargs):
"""
Evaluate the KL divergence, i.e. KL(a||b).
@ -510,6 +541,9 @@ class Distribution(Cell):
"""
return self._kl_loss(dist, *args, **kwargs)
def _mean(self, *args, **kwargs):
return raise_not_implemented_util('mean', self.name, *args, **kwargs)
def mean(self, *args, **kwargs):
"""
Evaluate the mean.
@ -524,6 +558,9 @@ class Distribution(Cell):
"""
return self._mean(*args, **kwargs)
def _mode(self, *args, **kwargs):
return raise_not_implemented_util('mode', self.name, *args, **kwargs)
def mode(self, *args, **kwargs):
"""
Evaluate the mode.
@ -584,6 +621,9 @@ class Distribution(Cell):
"""
return self.sq_base(self._sd(*args, **kwargs))
def _entropy(self, *args, **kwargs):
return raise_not_implemented_util('entropy', self.name, *args, **kwargs)
def entropy(self, *args, **kwargs):
"""
Evaluate the entropy.
@ -622,6 +662,9 @@ class Distribution(Cell):
"""
return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)
def _sample(self, *args, **kwargs):
return raise_not_implemented_util('sample', self.name, *args, **kwargs)
def sample(self, *args, **kwargs):
"""
Sampling function.
@ -680,4 +723,4 @@ class Distribution(Cell):
return self._call_cross_entropy(*args, **kwargs)
if name == 'sample':
return self._sample(*args, **kwargs)
return None
return raise_not_implemented_util(name, self.name, *args, **kwargs)

View File

@ -0,0 +1,102 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.
"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype
from mindspore import Tensor
from mindspore import context
func_name_list = ['prob', 'log_prob', 'cdf', 'log_cdf',
'survival_function', 'log_survival',
'sd', 'var', 'mode', 'mean',
'entropy', 'kl_loss', 'cross_entropy',
'sample']
class MyExponential(msd.Distribution):
"""
Test distirbution class: no function is implemented.
"""
def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
param = dict(locals())
param['param_dict'] = {'rate': rate}
super(MyExponential, self).__init__(seed, dtype, name, param)
class Net(nn.Cell):
"""
Test Net: function called through construct.
"""
def __init__(self, func_name):
super(Net, self).__init__()
self.dist = MyExponential()
self.name = func_name
def construct(self, *args, **kwargs):
return self.dist(self.name, *args, **kwargs)
def test_raise_not_implemented_error_construct():
"""
test raise not implemented error in pynative mode.
"""
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net(func_name)
net(value)
def test_raise_not_implemented_error_construct_graph_mode():
"""
test raise not implemented error in graph mode.
"""
context.set_context(mode=context.GRAPH_MODE)
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net(func_name)
net(value)
class Net1(nn.Cell):
"""
Test Net: function called directly.
"""
def __init__(self, func_name):
super(Net1, self).__init__()
self.dist = MyExponential()
self.func = getattr(self.dist, func_name)
def construct(self, *args, **kwargs):
return self.func(*args, **kwargs)
def test_raise_not_implemented_error():
"""
test raise not implemented error in pynative mode.
"""
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net1(func_name)
net(value)
def test_raise_not_implemented_error_graph_mode():
"""
test raise not implemented error in graph mode.
"""
context.set_context(mode=context.GRAPH_MODE)
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net1(func_name)
net(value)