forked from OSSInnovation/mindspore
!6908 Added raise_not_implement_error functionality in distribution
Merge pull request !6908 from XunDeng/not_impl_error
This commit is contained in:
commit
5171c90048
|
@ -218,6 +218,11 @@ def raise_not_impl_error(name):
|
|||
raise ValueError(
|
||||
f"{name} function should be implemented for non-linear transformation")
|
||||
|
||||
@constexpr
|
||||
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
f"{func_name} is not implemented for {obj} distribution.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_distribution_name(name, expected_name):
|
||||
|
|
|
@ -19,7 +19,8 @@ from mindspore.nn.cell import Cell
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.common import get_seed
|
||||
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
|
||||
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
|
||||
raise_not_implemented_util
|
||||
from ._utils.utils import CheckTuple, CheckTensor
|
||||
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
|
||||
|
||||
|
@ -245,6 +246,8 @@ class Distribution(Cell):
|
|||
self._call_prob = self._prob
|
||||
elif hasattr(self, '_log_prob'):
|
||||
self._call_prob = self._calc_prob_from_log_prob
|
||||
else:
|
||||
self._call_prob = self._raise_not_implemented_error('prob')
|
||||
|
||||
def _set_sd(self):
|
||||
"""
|
||||
|
@ -254,6 +257,8 @@ class Distribution(Cell):
|
|||
self._call_sd = self._sd
|
||||
elif hasattr(self, '_var'):
|
||||
self._call_sd = self._calc_sd_from_var
|
||||
else:
|
||||
self._call_sd = self._raise_not_implemented_error('sd')
|
||||
|
||||
def _set_var(self):
|
||||
"""
|
||||
|
@ -263,6 +268,8 @@ class Distribution(Cell):
|
|||
self._call_var = self._var
|
||||
elif hasattr(self, '_sd'):
|
||||
self._call_var = self._calc_var_from_sd
|
||||
else:
|
||||
self._call_var = self._raise_not_implemented_error('var')
|
||||
|
||||
def _set_log_prob(self):
|
||||
"""
|
||||
|
@ -272,6 +279,8 @@ class Distribution(Cell):
|
|||
self._call_log_prob = self._log_prob
|
||||
elif hasattr(self, '_prob'):
|
||||
self._call_log_prob = self._calc_log_prob_from_prob
|
||||
else:
|
||||
self._call_log_prob = self._raise_not_implemented_error('log_prob')
|
||||
|
||||
def _set_cdf(self):
|
||||
"""
|
||||
|
@ -286,13 +295,18 @@ class Distribution(Cell):
|
|||
self._call_cdf = self._calc_cdf_from_survival
|
||||
elif hasattr(self, '_log_survival'):
|
||||
self._call_cdf = self._calc_cdf_from_log_survival
|
||||
else:
|
||||
self._call_cdf = self._raise_not_implemented_error('cdf')
|
||||
|
||||
def _set_survival(self):
|
||||
"""
|
||||
Set survival function based on the availability of _survival function and `_log_survival`
|
||||
and `_call_cdf`.
|
||||
"""
|
||||
if hasattr(self, '_survival_function'):
|
||||
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \
|
||||
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
|
||||
self._call_survival = self._raise_not_implemented_error('survival_function')
|
||||
elif hasattr(self, '_survival_function'):
|
||||
self._call_survival = self._survival_function
|
||||
elif hasattr(self, '_log_survival'):
|
||||
self._call_survival = self._calc_survival_from_log_survival
|
||||
|
@ -303,7 +317,10 @@ class Distribution(Cell):
|
|||
"""
|
||||
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
|
||||
"""
|
||||
if hasattr(self, '_log_cdf'):
|
||||
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \
|
||||
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
|
||||
self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
|
||||
elif hasattr(self, '_log_cdf'):
|
||||
self._call_log_cdf = self._log_cdf
|
||||
elif hasattr(self, '_call_cdf'):
|
||||
self._call_log_cdf = self._calc_log_cdf_from_call_cdf
|
||||
|
@ -312,7 +329,10 @@ class Distribution(Cell):
|
|||
"""
|
||||
Set log survival based on the availability of `_log_survival` and `_call_survival`.
|
||||
"""
|
||||
if hasattr(self, '_log_survival'):
|
||||
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \
|
||||
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
|
||||
self._call_log_survival = self._raise_not_implemented_error('log_cdf')
|
||||
elif hasattr(self, '_log_survival'):
|
||||
self._call_log_survival = self._log_survival
|
||||
elif hasattr(self, '_call_survival'):
|
||||
self._call_log_survival = self._calc_log_survival_from_call_survival
|
||||
|
@ -323,6 +343,14 @@ class Distribution(Cell):
|
|||
"""
|
||||
if hasattr(self, '_cross_entropy'):
|
||||
self._call_cross_entropy = self._cross_entropy
|
||||
else:
|
||||
self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy')
|
||||
|
||||
def _raise_not_implemented_error(self, func_name):
|
||||
name = self.name
|
||||
def raise_error(*args, **kwargs):
|
||||
return raise_not_implemented_util(func_name, name, *args, **kwargs)
|
||||
return raise_error
|
||||
|
||||
def log_prob(self, value, *args, **kwargs):
|
||||
"""
|
||||
|
@ -495,6 +523,9 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self.log_base(self._call_survival(value, *args, **kwargs))
|
||||
|
||||
def _kl_loss(self, *args, **kwargs):
|
||||
return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)
|
||||
|
||||
def kl_loss(self, dist, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the KL divergence, i.e. KL(a||b).
|
||||
|
@ -510,6 +541,9 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self._kl_loss(dist, *args, **kwargs)
|
||||
|
||||
def _mean(self, *args, **kwargs):
|
||||
return raise_not_implemented_util('mean', self.name, *args, **kwargs)
|
||||
|
||||
def mean(self, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the mean.
|
||||
|
@ -524,6 +558,9 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self._mean(*args, **kwargs)
|
||||
|
||||
def _mode(self, *args, **kwargs):
|
||||
return raise_not_implemented_util('mode', self.name, *args, **kwargs)
|
||||
|
||||
def mode(self, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the mode.
|
||||
|
@ -584,6 +621,9 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self.sq_base(self._sd(*args, **kwargs))
|
||||
|
||||
def _entropy(self, *args, **kwargs):
|
||||
return raise_not_implemented_util('entropy', self.name, *args, **kwargs)
|
||||
|
||||
def entropy(self, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the entropy.
|
||||
|
@ -622,6 +662,9 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)
|
||||
|
||||
def _sample(self, *args, **kwargs):
|
||||
return raise_not_implemented_util('sample', self.name, *args, **kwargs)
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
"""
|
||||
Sampling function.
|
||||
|
@ -680,4 +723,4 @@ class Distribution(Cell):
|
|||
return self._call_cross_entropy(*args, **kwargs)
|
||||
if name == 'sample':
|
||||
return self._sample(*args, **kwargs)
|
||||
return None
|
||||
return raise_not_implemented_util(name, self.name, *args, **kwargs)
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Test nn.probability.distribution.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
func_name_list = ['prob', 'log_prob', 'cdf', 'log_cdf',
|
||||
'survival_function', 'log_survival',
|
||||
'sd', 'var', 'mode', 'mean',
|
||||
'entropy', 'kl_loss', 'cross_entropy',
|
||||
'sample']
|
||||
|
||||
class MyExponential(msd.Distribution):
|
||||
"""
|
||||
Test distirbution class: no function is implemented.
|
||||
"""
|
||||
def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'rate': rate}
|
||||
super(MyExponential, self).__init__(seed, dtype, name, param)
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test Net: function called through construct.
|
||||
"""
|
||||
def __init__(self, func_name):
|
||||
super(Net, self).__init__()
|
||||
self.dist = MyExponential()
|
||||
self.name = func_name
|
||||
|
||||
def construct(self, *args, **kwargs):
|
||||
return self.dist(self.name, *args, **kwargs)
|
||||
|
||||
|
||||
def test_raise_not_implemented_error_construct():
|
||||
"""
|
||||
test raise not implemented error in pynative mode.
|
||||
"""
|
||||
value = Tensor([0.2], dtype=mstype.float32)
|
||||
for func_name in func_name_list:
|
||||
with pytest.raises(NotImplementedError):
|
||||
net = Net(func_name)
|
||||
net(value)
|
||||
|
||||
def test_raise_not_implemented_error_construct_graph_mode():
|
||||
"""
|
||||
test raise not implemented error in graph mode.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
value = Tensor([0.2], dtype=mstype.float32)
|
||||
for func_name in func_name_list:
|
||||
with pytest.raises(NotImplementedError):
|
||||
net = Net(func_name)
|
||||
net(value)
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""
|
||||
Test Net: function called directly.
|
||||
"""
|
||||
def __init__(self, func_name):
|
||||
super(Net1, self).__init__()
|
||||
self.dist = MyExponential()
|
||||
self.func = getattr(self.dist, func_name)
|
||||
|
||||
def construct(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def test_raise_not_implemented_error():
|
||||
"""
|
||||
test raise not implemented error in pynative mode.
|
||||
"""
|
||||
value = Tensor([0.2], dtype=mstype.float32)
|
||||
for func_name in func_name_list:
|
||||
with pytest.raises(NotImplementedError):
|
||||
net = Net1(func_name)
|
||||
net(value)
|
||||
|
||||
def test_raise_not_implemented_error_graph_mode():
|
||||
"""
|
||||
test raise not implemented error in graph mode.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
value = Tensor([0.2], dtype=mstype.float32)
|
||||
for func_name in func_name_list:
|
||||
with pytest.raises(NotImplementedError):
|
||||
net = Net1(func_name)
|
||||
net(value)
|
Loading…
Reference in New Issue